From 22bbd5471987f17fd333bc07e157926dd9bb665c Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Fri, 21 Feb 2020 18:04:13 +0800 Subject: [PATCH] Add the support of fp16 in fusion_group (#22239) --- .../ir/fusion_group/code_generator.cc | 35 ++- .../ir/fusion_group/code_generator.h | 4 +- .../ir/fusion_group/code_generator_helper.h | 25 -- .../ir/fusion_group/code_generator_tester.cc | 281 ++++++++++-------- .../ir/fusion_group/cuda_resources.h | 82 +++++ .../ir/fusion_group/fusion_group_pass.cc | 37 +-- .../ir/fusion_group/fusion_group_pass.h | 2 +- .../fusion_group/fusion_group_pass_tester.cc | 10 + .../framework/ir/fusion_group/operation.cc | 23 +- .../framework/ir/fusion_group/subgraph.h | 46 ++- .../fluid/framework/ir/pass_tester_helper.h | 9 +- .../operators/fused/fusion_group_op.cu.cc | 7 +- paddle/fluid/platform/device_code.cc | 55 +++- paddle/fluid/platform/device_code.h | 4 +- .../fluid/tests/unittests/ir/CMakeLists.txt | 2 +- .../unittests/ir/test_ir_fusion_group.py | 109 ------- .../unittests/ir/test_ir_fusion_group_pass.py | 142 +++++++++ 17 files changed, 570 insertions(+), 303 deletions(-) create mode 100644 paddle/fluid/framework/ir/fusion_group/cuda_resources.h delete mode 100644 python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group.py create mode 100644 python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.cc b/paddle/fluid/framework/ir/fusion_group/code_generator.cc index 0f9ee83a411..b7a75d376a5 100644 --- a/paddle/fluid/framework/ir/fusion_group/code_generator.cc +++ b/paddle/fluid/framework/ir/fusion_group/code_generator.cc @@ -16,6 +16,7 @@ limitations under the License. 
*/ #include #include #include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h" +#include "paddle/fluid/framework/ir/fusion_group/cuda_resources.h" #include "paddle/fluid/framework/ir/fusion_group/operation.h" namespace paddle { @@ -27,13 +28,14 @@ CodeGenerator::CodeGenerator() { // Only support elementwise operations now. code_templates_.resize(1); - CodeTemplate elementwise_t(elementwise_cuda_template); + CodeTemplate elementwise_t(cuda_kernel_template_1d); code_templates_[0] = elementwise_t; } std::string CodeGenerator::Generate(SubGraph* subgraph) { std::vector expressions = ConvertToExpressions(subgraph); - return Generate(subgraph->GetFuncName(), expressions); + return Generate(subgraph->GetFuncName(), subgraph->GetDataType(), + expressions); } static bool HasInput(Node* n, std::string name) { @@ -100,9 +102,9 @@ std::vector CodeGenerator::ConvertToExpressions( // In order to get the right result of expression, we need to calculate and // store the expression as suffix Expressions using vector. std::string CodeGenerator::Generate( - std::string func_name, std::vector expressions) { + std::string func_name, std::string dtype, + const std::vector& expressions) { // TODO(liuyiqun): Check whether all expressions are elementwise operations. 
- std::string dtype = "float"; std::set input_ids = DistilInputIds(expressions); std::set output_ids = DistilOutputIds(expressions); @@ -111,6 +113,15 @@ std::string CodeGenerator::Generate( template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtype)); template_var.Add("compute_body", EmitComputeBody(expressions, input_ids, output_ids, dtype)); + + std::string predefined_cuda_functions; + if (dtype == "float") { + predefined_cuda_functions = predefined_cuda_functions_fp32; + } else if (dtype == "double") { + predefined_cuda_functions = predefined_cuda_functions_fp64; + } else if (dtype == "float16") { + predefined_cuda_functions = predefined_cuda_functions_fp16; + } return predefined_cuda_functions + code_templates_[0].Format(template_var); } @@ -173,9 +184,10 @@ std::string CodeGenerator::EmitComputeBody( std::string dtype) { std::ostringstream compute; std::unordered_set used; + std::string compute_dtype = (dtype == "float16") ? "float" : dtype; for (size_t i = 0; i < expressions.size(); i++) { VLOG(3) << DebugString(expressions[i]); - compute << expressions[i].GetExpression(dtype, &used); + compute << expressions[i].GetExpression(compute_dtype, &used); } // Load input to temporal variables. @@ -183,14 +195,23 @@ std::string CodeGenerator::EmitComputeBody( for (auto id : input_ids) { if (output_ids.find(id) == output_ids.end() && used.find(id) != used.end()) { - load << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];"; + if (dtype == "float16") { + load << "float " << TmpName(id) << " = __half2float(" << ArgName(id) + << "[idx]);"; + } else { + load << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];"; + } } } // Store temporal variables to memory. 
std::ostringstream store; for (auto id : output_ids) { - store << ArgName(id) << "[idx] = " << TmpName(id) << ";"; + if (dtype == "float16") { + store << ArgName(id) << "[idx] = __float2half(" << TmpName(id) << ");"; + } else { + store << ArgName(id) << "[idx] = " << TmpName(id) << ";"; + } } return load.str() + compute.str() + store.str(); diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.h b/paddle/fluid/framework/ir/fusion_group/code_generator.h index 22d66611182..ce1bcc48e65 100644 --- a/paddle/fluid/framework/ir/fusion_group/code_generator.h +++ b/paddle/fluid/framework/ir/fusion_group/code_generator.h @@ -30,8 +30,8 @@ class CodeGenerator { public: CodeGenerator(); - std::string Generate(std::string func_name, - std::vector expressions); + std::string Generate(std::string func_name, std::string dtype, + const std::vector& expressions); std::string Generate(SubGraph* subgraph); diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h index 140e0d3a06b..5749755d3ab 100644 --- a/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h +++ b/paddle/fluid/framework/ir/fusion_group/code_generator_helper.h @@ -149,31 +149,6 @@ class CodeTemplate { std::string template_str_; }; -static const char predefined_cuda_functions[] = R"( -__device__ float real_exp(float x) { return ::expf(x); } -__device__ double real_exp(double x) { return ::exp(x); } - -__device__ float real_log(float x) { return ::logf(x); } -__device__ double real_log(double x) { return ::log(x); } - -__device__ float real_min(float x, float y) { return ::fminf(x, y); } -__device__ double real_min(double x, double y) { return ::fmin(x, y); } - -__device__ float real_max(float x, float y) { return ::fmaxf(x, y); } -__device__ double real_max(double x, double y) { return ::fmax(x, y); } - -)"; - -static const char elementwise_cuda_template[] = R"( -extern "C" __global__ void $func_name($parameters) { 
- for(int idx = blockIdx.x * blockDim.x + threadIdx.x; - idx < N; - idx += gridDim.x * blockDim.x) { - $compute_body - } -} -)"; - static std::string DebugString(const OperationExpression& expr) { std::stringstream ret; ret << "Op(" << expr.GetOpType() << "), inputs:{"; diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc b/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc index a5409cb9d6a..8f4eb7443ff 100644 --- a/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc +++ b/paddle/fluid/framework/ir/fusion_group/code_generator_tester.cc @@ -22,6 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/operators/math.h" #include "paddle/fluid/platform/device_code.h" +#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/init.h" #ifdef PADDLE_WITH_CUDA @@ -88,7 +89,8 @@ inline float elementwise_mul_grad_dy(float x, float y, float out, float dout) { void CheckOutput(const std::vector& expressions, const std::vector cpu_tensors, const std::vector input_ids_of_subgraph, - const std::vector output_ids_of_subgraph, int i) { + const std::vector output_ids_of_subgraph, int i, + float eps) { std::vector var(cpu_tensors.size()); for (auto id : input_ids_of_subgraph) { if (id >= 0) { @@ -138,7 +140,12 @@ void CheckOutput(const std::vector& expressions, for (auto id : output_ids_of_subgraph) { float actual = cpu_tensors[id].data()[i]; float expect = var[id]; - EXPECT_LT(fabs(actual - expect), 1.E-05); + if (fabs(actual - expect) > eps) { + LOG(INFO) << "Precision check failed from i = " << id + << ", expect: " << expect << ", actual: " << actual; + EXPECT_LT(fabs(actual - expect), eps); + break; + } } } @@ -162,33 +169,49 @@ void SetupRandomCPUTensor(LoDTensor* tensor) { namespace fusion_group = paddle::framework::ir::fusion_group; +template void TestMainImpl(std::string func_name, std::string code_str, std::vector cpu_tensors, int n, std::vector 
input_ids, std::vector output_ids) { + bool is_float16 = std::type_index(typeid(T)) == + std::type_index(typeid(paddle::platform::float16)); + paddle::framework::InitDevices(false, {0}); paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0); paddle::platform::CUDADeviceCode device_code(place, func_name, code_str); - device_code.Compile(); + device_code.Compile(is_float16); std::vector gpu_tensors(cpu_tensors.size()); + std::vector tmp_cpu_tensors(cpu_tensors.size()); - std::vector gpu_ptrs(gpu_tensors.size()); + std::vector gpu_ptrs(gpu_tensors.size()); std::vector args; args.push_back(&n); for (auto id : input_ids) { if (id >= 0) { gpu_ptrs[id] = - gpu_tensors[id].mutable_data(cpu_tensors[id].dims(), place); + gpu_tensors[id].mutable_data(cpu_tensors[id].dims(), place); fusion_group::SetupRandomCPUTensor(&cpu_tensors[id]); - TensorCopySync(cpu_tensors[id], place, &gpu_tensors[id]); + if (is_float16) { + paddle::platform::float16* tmp_cpu_ptr = + tmp_cpu_tensors[id].mutable_data( + cpu_tensors[id].dims(), paddle::platform::CPUPlace()); + const float* cpu_ptr = cpu_tensors[id].data(); + for (int64_t i = 0; i < cpu_tensors[id].numel(); ++i) { + tmp_cpu_ptr[i] = paddle::platform::float16(cpu_ptr[i]); + } + TensorCopySync(tmp_cpu_tensors[id], place, &gpu_tensors[id]); + } else { + TensorCopySync(cpu_tensors[id], place, &gpu_tensors[id]); + } args.push_back(&gpu_ptrs[id]); } } for (auto id : output_ids) { gpu_ptrs[id] = - gpu_tensors[id].mutable_data(cpu_tensors[id].dims(), place); + gpu_tensors[id].mutable_data(cpu_tensors[id].dims(), place); args.push_back(&gpu_ptrs[id]); } @@ -200,38 +223,93 @@ void TestMainImpl(std::string func_name, std::string code_str, paddle::platform::DeviceContextPool::Instance().Get(place)); dev_ctx->Wait(); + // Copy the results back to CPU. 
for (auto id : output_ids) { - TensorCopySync(gpu_tensors[id], paddle::platform::CPUPlace(), - &cpu_tensors[id]); + if (is_float16) { + paddle::platform::float16* tmp_cpu_ptr = + tmp_cpu_tensors[id].mutable_data( + cpu_tensors[id].dims(), paddle::platform::CPUPlace()); + TensorCopySync(gpu_tensors[id], paddle::platform::CPUPlace(), + &tmp_cpu_tensors[id]); + + float* cpu_ptr = cpu_tensors[id].mutable_data( + cpu_tensors[id].dims(), paddle::platform::CPUPlace()); + for (int64_t i = 0; i < cpu_tensors[id].numel(); ++i) { + cpu_ptr[i] = static_cast(tmp_cpu_ptr[i]); + } + } else { + TensorCopySync(gpu_tensors[id], paddle::platform::CPUPlace(), + &cpu_tensors[id]); + } + } +} + +void TestElementwiseMain( + std::string func_name, std::string code_str, + std::vector expressions, + std::vector input_ids, std::vector output_ids, + std::string dtype) { + std::unordered_set ids; + for (auto id : input_ids) { + ids.insert(id); + } + for (auto id : output_ids) { + ids.insert(id); + } + + // Prepare CPU tensors which always hold float. + std::vector cpu_tensors(ids.size()); + auto dims = paddle::framework::make_ddim( + {static_cast(256), static_cast(1024)}); + for (size_t i = 0; i < cpu_tensors.size(); ++i) { + cpu_tensors[i].mutable_data(dims, paddle::platform::CPUPlace()); + } + + int n = cpu_tensors[0].numel(); + if (dtype == "float16") { + TestMainImpl(func_name, code_str, cpu_tensors, n, + input_ids, output_ids); + } else { + TestMainImpl(func_name, code_str, cpu_tensors, n, input_ids, + output_ids); + } + + // Check the results + float eps = (dtype == "float16") ? 
1E-2 : 1E-5; + for (int i = 0; i < n; i++) { + fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids, + i, eps); } } void TestMain(std::string func_name, std::vector expressions, - std::vector cpu_tensors, int n, - std::vector input_ids, std::vector output_ids) { + std::vector input_ids, std::vector output_ids, + std::string dtype) { fusion_group::OperationMap::Init(); fusion_group::CodeGenerator code_generator; - std::string code_str = code_generator.Generate(func_name, expressions); + std::string code_str = code_generator.Generate(func_name, dtype, expressions); VLOG(3) << code_str; - TestMainImpl(func_name, code_str, cpu_tensors, n, input_ids, output_ids); + LOG(INFO) << "dtype: " << dtype; + TestElementwiseMain(func_name, code_str, expressions, input_ids, output_ids, + dtype); } -std::vector TestMain( - fusion_group::SubGraph* subgraph, - std::vector cpu_tensors, int n, - std::vector input_ids, std::vector output_ids) { +void TestMain(fusion_group::SubGraph* subgraph, std::vector input_ids, + std::vector output_ids) { fusion_group::OperationMap::Init(); fusion_group::CodeGenerator code_generator; std::string code_str = code_generator.Generate(subgraph); VLOG(3) << code_str; - TestMainImpl(subgraph->GetFuncName(), code_str, cpu_tensors, n, input_ids, - output_ids); - // Need to check the accuracy according to expressions. 
- return code_generator.ConvertToExpressions(subgraph); + std::vector expressions = + code_generator.ConvertToExpressions(subgraph); + + LOG(INFO) << "dtype: " << subgraph->GetDataType(); + TestElementwiseMain(subgraph->GetFuncName(), code_str, expressions, input_ids, + output_ids, subgraph->GetDataType()); } TEST(code_generator, elementwise) { @@ -248,30 +326,16 @@ TEST(code_generator, elementwise) { std::vector expressions = { exp1, exp2, exp3, exp4, exp5}; - // Prepare CPU tensors - std::vector cpu_tensors(9); - auto dims = paddle::framework::make_ddim( - {static_cast(256), static_cast(1024)}); - for (size_t i = 0; i < cpu_tensors.size(); ++i) { - cpu_tensors[i].mutable_data(dims, paddle::platform::CPUPlace()); - } - - // Expressions: - // Op(elementwise_mul), inputs:{0,1}, outputs:{2} - // Op(elementwise_add), inputs:{2,3}, outputs:{4} - // Op(elementwise_sub), inputs:{4,5}, outputs:{6} - // Op(relu), inputs:{6}, outputs:{7} - // Op(sigmoid), inputs:{7}, outputs:{8} - int n = cpu_tensors[0].numel(); - std::vector input_ids = {0, 1, 3, 5}; - std::vector output_ids = {2, 4, 6, 7, 8}; - TestMain("elementwise_kernel_0", expressions, cpu_tensors, n, input_ids, - output_ids); - - // Check the results - for (int i = 0; i < n; i++) { - fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids, - i); + for (std::string dtype : {"float", "float16"}) { + // Expressions: + // Op(elementwise_mul), inputs:{0,1}, outputs:{2} + // Op(elementwise_add), inputs:{2,3}, outputs:{4} + // Op(elementwise_sub), inputs:{4,5}, outputs:{6} + // Op(relu), inputs:{6}, outputs:{7} + // Op(sigmoid), inputs:{7}, outputs:{8} + std::vector input_ids = {0, 1, 3, 5}; + std::vector output_ids = {2, 4, 6, 7, 8}; + TestMain("elementwise_kernel_0", expressions, input_ids, output_ids, dtype); } } @@ -286,32 +350,19 @@ TEST(code_generator, elementwise_grad) { {4, 5}); std::vector expressions = {exp1, exp2}; - // Prepare CPU tensors - std::vector cpu_tensors(8); - auto dims = 
paddle::framework::make_ddim( - {static_cast(256), static_cast(1024)}); - for (size_t i = 0; i < cpu_tensors.size(); ++i) { - cpu_tensors[i].mutable_data(dims, paddle::platform::CPUPlace()); - } - - // Expressions: - // Op(relu_grad), inputs:{2,3,7}, outputs:{6} - // Op(elementwise_mul_grad), inputs:{0,1,2,6}, outputs:{4,5} - int n = cpu_tensors[0].numel(); - std::vector input_ids = {0, 1, 2, 3, 7}; - std::vector output_ids = {4, 5, 6}; - TestMain("elementwise_grad_kernel_0", expressions, cpu_tensors, n, input_ids, - output_ids); - - // Check the results - for (int i = 0; i < n; i++) { - fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids, - i); + for (std::string dtype : {"float", "float16"}) { + // Expressions: + // Op(relu_grad), inputs:{2,3,7}, outputs:{6} + // Op(elementwise_mul_grad), inputs:{0,1,2,6}, outputs:{4,5} + std::vector input_ids = {0, 1, 2, 3, 7}; + std::vector output_ids = {4, 5, 6}; + TestMain("elementwise_grad_kernel_0", expressions, input_ids, output_ids, + dtype); } } -std::unique_ptr BuildGraph( - bool backward = false) { +std::unique_ptr BuildGraph(bool backward, + std::string dtype) { // inputs operator output // -------------------------------------------------------- // x0 sigmoid -> tmp_0 @@ -353,6 +404,14 @@ std::unique_ptr BuildGraph( std::unique_ptr graph( new paddle::framework::ir::Graph(layers.main_program())); + auto proto_dtype = (dtype == "float16") + ? 
paddle::framework::proto::VarType::FP16 + : paddle::framework::proto::VarType::FP32; + for (auto* n : graph->Nodes()) { + if (n && n->IsVar() && n->Var()) { + n->Var()->SetDataType(proto_dtype); + } + } #ifdef __clang__ return graph; #else @@ -401,66 +460,40 @@ std::unordered_set DistilGradNodes( } TEST(code_generator, subgraph) { - std::unique_ptr graph = BuildGraph(false); - fusion_group::SubGraph subgraph(0, "elementwise_kernel_1", true, - graph->Nodes()); - - // Prepare CPU tensors - std::vector cpu_tensors(9); - auto dims = paddle::framework::make_ddim( - {static_cast(256), static_cast(1024)}); - for (size_t i = 0; i < cpu_tensors.size(); ++i) { - cpu_tensors[i].mutable_data(dims, paddle::platform::CPUPlace()); - } - - // Expressions generated by code_generator (they may be different): - // Op(sigmoid), inputs:{0}, outputs:{4} - // Op(elementwise_mul), inputs:{4,1}, outputs:{7} - // Op(tanh), inputs:{2}, outputs:{5} - // Op(elementwise_mul), inputs:{3,5}, outputs:{6} - // Op(elementwise_add), inputs:{7,6}, outputs:{8} - int n = cpu_tensors[0].numel(); - std::vector input_ids = {0, 1, 2, 3}; - std::vector output_ids = {4, 5, 6, 7, 8}; - std::vector expressions = - TestMain(&subgraph, cpu_tensors, n, input_ids, output_ids); - - // Check the results - for (int i = 0; i < n; i++) { - fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids, - i); + for (std::string dtype : {"float", "float16"}) { + std::unique_ptr graph = + BuildGraph(false, dtype); + fusion_group::SubGraph subgraph(0, "elementwise_kernel_1", true, + graph->Nodes()); + + // Expressions generated by code_generator (they may be different): + // Op(sigmoid), inputs:{0}, outputs:{4} + // Op(elementwise_mul), inputs:{4,1}, outputs:{7} + // Op(tanh), inputs:{2}, outputs:{5} + // Op(elementwise_mul), inputs:{3,5}, outputs:{6} + // Op(elementwise_add), inputs:{7,6}, outputs:{8} + std::vector input_ids = {0, 1, 2, 3}; + std::vector output_ids = {4, 5, 6, 7, 8}; + TestMain(&subgraph, 
input_ids, output_ids); } } TEST(code_generator, subgraph_grad) { - std::unique_ptr graph = BuildGraph(true); - fusion_group::SubGraph subgraph(0, "elementwise_grad_kernel_1", true, - DistilGradNodes(graph)); - - // Prepare CPU tensors - std::vector cpu_tensors(18); - auto dims = paddle::framework::make_ddim( - {static_cast(256), static_cast(1024)}); - for (size_t i = 0; i < cpu_tensors.size(); ++i) { - cpu_tensors[i].mutable_data(dims, paddle::platform::CPUPlace()); - } - - // Expressions generated by code_generator (they may be different): - // Op(elementwise_add_grad), inputs:{1,2,3,0}, outputs:{11,10} - // Op(elementwise_mul_grad), inputs:{5,4,2,10}, outputs:{17,13} - // Op(elementwise_mul_grad), inputs:{7,6,1,11}, outputs:{12,15} - // Op(sigmoid_grad), inputs:{8,7,12}, outputs:{16} - // Op(tanh_grad), inputs:{9,4,13}, outputs:{14} - int n = cpu_tensors[0].numel(); - std::vector input_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - std::vector output_ids = {10, 11, 12, 13, 14, 15, 16, 17}; - std::vector expressions = - TestMain(&subgraph, cpu_tensors, n, input_ids, output_ids); - - // Check the results - for (int i = 0; i < n; i++) { - fusion_group::CheckOutput(expressions, cpu_tensors, input_ids, output_ids, - i); + for (std::string dtype : {"float", "float16"}) { + std::unique_ptr graph = + BuildGraph(true, dtype); + fusion_group::SubGraph subgraph(0, "elementwise_grad_kernel_1", true, + DistilGradNodes(graph)); + + // Expressions generated by code_generator (they may be different): + // Op(elementwise_add_grad), inputs:{1,2,3,0}, outputs:{11,10} + // Op(elementwise_mul_grad), inputs:{5,4,2,10}, outputs:{17,13} + // Op(elementwise_mul_grad), inputs:{7,6,1,11}, outputs:{12,15} + // Op(sigmoid_grad), inputs:{8,7,12}, outputs:{16} + // Op(tanh_grad), inputs:{9,4,13}, outputs:{14} + std::vector input_ids = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::vector output_ids = {10, 11, 12, 13, 14, 15, 16, 17}; + TestMain(&subgraph, input_ids, output_ids); } } #endif diff --git 
a/paddle/fluid/framework/ir/fusion_group/cuda_resources.h b/paddle/fluid/framework/ir/fusion_group/cuda_resources.h new file mode 100644 index 00000000000..e4382e205ba --- /dev/null +++ b/paddle/fluid/framework/ir/fusion_group/cuda_resources.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace framework { +namespace ir { +namespace fusion_group { + +static constexpr char predefined_cuda_functions_fp32[] = R"( +__device__ inline float real_exp(float x) { return ::expf(x); } +__device__ inline float real_log(float x) { return ::logf(x); } + +)"; + +static constexpr char predefined_cuda_functions_fp64[] = R"( +__device__ inline double real_exp(double x) { return ::exp(x); } +__device__ inline double real_log(double x) { return ::log(x); } + +)"; + +static constexpr char predefined_cuda_functions_fp16[] = R"( +__device__ inline float real_exp(float x) { return ::expf(x); } +__device__ inline float real_log(float x) { return ::logf(x); } + +#define __HALF_TO_US(var) *(reinterpret_cast(&(var))) +#define __HALF_TO_CUS(var) *(reinterpret_cast(&(var))) + +struct __align__(2) __half { + __device__ __half() { } + + protected: + unsigned short __x; +}; + +__device__ __half __float2half(const float f) { + __half val; + asm("{ cvt.rn.f16.f32 %0, %1; }\n" : "=h"(__HALF_TO_US(val) + +) : "f"(f)); + return val; +} + +__device__ float __half2float(const __half h) { + float 
val; + asm("{ cvt.f32.f16 %0, %1; }\n" : "=f"(val) : "h"(__HALF_TO_CUS(h))); + return val; +} + +#undef __HALF_TO_US +#undef __HALF_TO_CUS + +typedef __half float16; + +)"; + +static constexpr char cuda_kernel_template_1d[] = R"( +extern "C" __global__ void $func_name($parameters) { + for(int idx = blockIdx.x * blockDim.x + threadIdx.x; + idx < N; + idx += gridDim.x * blockDim.x) { + $compute_body + } +} +)"; + +} // namespace fusion_group +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc index 1d9d4ab5d23..787bfe58987 100644 --- a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc +++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc @@ -32,8 +32,7 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const { if (Get("use_gpu")) { fusion_group::OperationMap::Init(); int num_elementwise_groups = DetectFusionGroup(graph, 0); - VLOG(3) << "Detect " << num_elementwise_groups - << " elementwise fusion groups."; + AddStatis(num_elementwise_groups); } } @@ -49,23 +48,23 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const { size_t min_subgraph_size = 2; bool save_intermediate_out = true; for (auto& vec : subgraphs) { - if (vec.size() >= min_subgraph_size) { - std::string func_name = "fused_elementwise_" + std::to_string(index++); - fusion_group::SubGraph subgraph( - type, func_name, save_intermediate_out, - std::unordered_set(vec.begin(), vec.end())); - VLOG(3) << "subgraph: {\n" - << DebugString(subgraph.SortedNodes()) << "}\n"; - - GenerateCode(&subgraph); - InsertFusionGroupOp(graph, &subgraph); - num_subgraphs++; + fusion_group::SubGraph subgraph( + type, "", save_intermediate_out, + std::unordered_set(vec.begin(), vec.end())); + VLOG(3) << "subgraph: {\n" << DebugString(subgraph.SortedNodes()) << "}\n"; + + if (subgraph.IsValid(min_subgraph_size)) { + 
subgraph.SetFuncName("fused_elementwise_" + std::to_string(index++)); + if (GenerateCode(&subgraph)) { + InsertFusionGroupOp(graph, &subgraph); + num_subgraphs++; + } } } return num_subgraphs; } -void FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const { +bool FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const { fusion_group::CodeGenerator code_generator; std::string code_str = code_generator.Generate(subgraph); VLOG(3) << code_str; @@ -74,10 +73,12 @@ void FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const { platform::CUDAPlace place = platform::CUDAPlace(0); std::unique_ptr device_code( new platform::CUDADeviceCode(place, subgraph->GetFuncName(), code_str)); - device_code->Compile(); - - platform::DeviceCodePool& pool = platform::DeviceCodePool::Init({place}); - pool.Set(std::move(device_code)); + bool is_compiled = device_code->Compile(); + if (is_compiled) { + platform::DeviceCodePool& pool = platform::DeviceCodePool::Init({place}); + pool.Set(std::move(device_code)); + } + return is_compiled; } static int ExtractOpRole(fusion_group::SubGraph* subgraph) { diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h index 72c7250e720..3438783c180 100644 --- a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h +++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h @@ -29,7 +29,7 @@ class FusionGroupPass : public FusePassBase { private: int DetectFusionGroup(Graph* graph, int type = 0) const; - void GenerateCode(fusion_group::SubGraph* subgraph) const; + bool GenerateCode(fusion_group::SubGraph* subgraph) const; void InsertFusionGroupOp(Graph* graph, fusion_group::SubGraph* subgraph) const; diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass_tester.cc b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass_tester.cc index 2446716019c..de48c8772bf 100644 --- 
a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass_tester.cc +++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass_tester.cc @@ -59,6 +59,11 @@ std::unique_ptr BuildElementwiseListGraph(bool backward = false) { } std::unique_ptr graph(new Graph(layers.main_program())); + for (auto* n : graph->Nodes()) { + if (n && n->IsVar() && n->Var()) { + n->Var()->SetDataType(proto::VarType::FP32); + } + } #ifdef __clang__ return graph; #else @@ -116,6 +121,11 @@ std::unique_ptr BuildElementwiseTreeGraph(bool backward = false) { } std::unique_ptr graph(new Graph(layers.main_program())); + for (auto* n : graph->Nodes()) { + if (n && n->IsVar() && n->Var()) { + n->Var()->SetDataType(proto::VarType::FP32); + } + } #ifdef __clang__ return graph; #else diff --git a/paddle/fluid/framework/ir/fusion_group/operation.cc b/paddle/fluid/framework/ir/fusion_group/operation.cc index 912044611f6..966cc2752c0 100644 --- a/paddle/fluid/framework/ir/fusion_group/operation.cc +++ b/paddle/fluid/framework/ir/fusion_group/operation.cc @@ -91,7 +91,7 @@ void OperationMap::InsertUnaryElementwiseOperations() { // relu: // out = f(x) = x > 0 ? x : 0 // dx = dout * (out > 0 ? 1 : 0) - insert_handler("relu", "real_max(${0}, 0)", {"${1} > 0 ? ${2} : 0"}); + insert_handler("relu", "${0} > 0 ? ${0} : 0", {"${1} > 0 ? ${2} : 0"}); // sigmoid: // out = f(x) = 1.0 / (1.0 + exp(-x)) // dx = dout * out * (1 - out) @@ -133,9 +133,24 @@ void OperationMap::InsertBinaryElementwiseOperations() { // dy = dout * x insert_handler("elementwise_mul", "${0} * ${1}", {"${3} * ${1}", "${3} * ${0}"}); - insert_handler("elementwise_div", "${0} / ${1}", {}); - insert_handler("elementwise_min", "real_min(${0}, ${1})", {}); - insert_handler("elementwise_max", "real_max(${0}, ${1})", {}); + // elementwise_div: + // out = x / y + // dx = dout / y + // dy = - dout * out / y + insert_handler("elementwise_div", "${0} / ${1}", + {"${3} / ${1}", "- ${3} * ${2} / ${1}"}); + // elementwise_min: + // out = x < y ? 
x : y + // dx = dout * (x < y) + // dy = dout * (x >= y) + insert_handler("elementwise_min", "${0} < ${1} ? ${0} : ${1}", + {"${3} * (${0} < ${1})", "${3} * (${0} >= ${1})"}); + // elementwise_max: + // out = x > y ? x : y + // dx = dout * (x > y) + // dy = dout * (x <= y) + insert_handler("elementwise_max", "${0} > ${1} ? ${0} : ${1}", + {"${3} * (${0} > ${1})", "${3} * (${0} <= ${1})"}); } } // namespace fusion_group diff --git a/paddle/fluid/framework/ir/fusion_group/subgraph.h b/paddle/fluid/framework/ir/fusion_group/subgraph.h index b9810882e1c..35247ece490 100644 --- a/paddle/fluid/framework/ir/fusion_group/subgraph.h +++ b/paddle/fluid/framework/ir/fusion_group/subgraph.h @@ -49,11 +49,23 @@ class SubGraph { } } } + ExtractDataType(); } - bool IsEmpty() { return nodes_set_.empty(); } + bool IsValid(int min_subgraph_size) { + int num_operations = GetNumOperations(); + if (num_operations < min_subgraph_size) { + VLOG(2) << "There are only " << num_operations + << " operations in the subgraph. Expected at least " + << min_subgraph_size; + return false; + } + + return ExtractDataType(); + } int GetType() const { return type_; } + std::string GetDataType() const { return data_type_; } void SetFuncName(std::string func_name) { func_name_ = func_name; } std::string GetFuncName() const { return func_name_; } @@ -150,6 +162,37 @@ class SubGraph { } private: + bool ExtractDataType() { + bool is_first = true; + proto::VarType::Type data_type = proto::VarType::FP32; + for (auto* n : nodes_set_) { + if (n && n->IsVar() && n->Var()) { + if (n->Var()->GetType() != proto::VarType::LOD_TENSOR) { + // All var node in a subgraph should hold a LoDTensor. + return false; + } + if (is_first) { + data_type = n->Var()->GetDataType(); + is_first = false; + } else if (n->Var()->GetDataType() != data_type) { + // DataType of VarDesc in a subgraph is not the same. 
+ return false; + } + } + } + if (data_type == proto::VarType::FP32) { + data_type_ = "float"; + } else if (data_type == proto::VarType::FP64) { + data_type_ = "double"; + } else if (data_type == proto::VarType::FP16) { + data_type_ = "float16"; + } else { + VLOG(2) << "Only support fp32, fp64 and fp16 in fusion_group."; + return false; + } + return true; + } + void TopologicalSort() { if (!is_sorted_) { std::unordered_map> inputs_map; @@ -203,6 +246,7 @@ class SubGraph { private: int type_{-1}; + std::string data_type_; std::string func_name_; bool save_intermediate_out_{true}; diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 82f9e726613..d8595ad51a8 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -33,8 +33,9 @@ struct Layers { const ProgramDesc& main_program() { return program_; } VarDesc* data(std::string name, std::vector shape = {}, - bool is_persistable = false) { - return lod_tensor(name, shape, is_persistable); + bool is_persistable = false, + proto::VarType::Type data_type = proto::VarType::FP32) { + return lod_tensor(name, shape, is_persistable, data_type); } VarDesc* conv2d(VarDesc* input, VarDesc* filter, VarDesc* bias, @@ -379,9 +380,11 @@ struct Layers { private: VarDesc* lod_tensor(std::string name, std::vector shape = {}, - bool is_persistable = false) { + bool is_persistable = false, + proto::VarType::Type data_type = proto::VarType::FP32) { auto* var = program_.MutableBlock(0)->Var(name); var->SetType(proto::VarType::LOD_TENSOR); + var->SetDataType(data_type); var->SetShape(shape); var->SetPersistable(is_persistable); return var; diff --git a/paddle/fluid/operators/fused/fusion_group_op.cu.cc b/paddle/fluid/operators/fused/fusion_group_op.cu.cc index 63c243beafb..94949f56331 100644 --- a/paddle/fluid/operators/fused/fusion_group_op.cu.cc +++ b/paddle/fluid/operators/fused/fusion_group_op.cu.cc @@ -13,10 +13,11 @@ 
See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/fused/fusion_group_op.h" +#include "paddle/fluid/platform/float16.h" namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( - fusion_group, - ops::FusionGroupKernel, - ops::FusionGroupKernel); + fusion_group, ops::FusionGroupKernel, + ops::FusionGroupKernel, + ops::FusionGroupKernel); diff --git a/paddle/fluid/platform/device_code.cc b/paddle/fluid/platform/device_code.cc index b26de68395d..4f13f8e3889 100644 --- a/paddle/fluid/platform/device_code.cc +++ b/paddle/fluid/platform/device_code.cc @@ -13,11 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/device_code.h" +#include <sys/stat.h> #include #include #include #include "paddle/fluid/platform/enforce.h" +DECLARE_string(cuda_dir); + namespace paddle { namespace platform { @@ -79,6 +82,46 @@ DeviceCodePool::DeviceCodePool(const std::vector& places) { } #ifdef PADDLE_WITH_CUDA +static std::string FindCUDAIncludePath() { + auto EndWith = [](std::string str, std::string substr) -> bool { + size_t pos = str.rfind(substr); + return pos != std::string::npos && pos == (str.length() - substr.length()); + }; + // Try FLAGS_cuda_dir first, then fall back to /usr/local/cuda/include. + struct stat st; + std::string cuda_include_path; + if (!FLAGS_cuda_dir.empty()) { + cuda_include_path = FLAGS_cuda_dir; + if (EndWith(cuda_include_path, "/")) { + cuda_include_path.erase(cuda_include_path.end() - 1); + } + for (std::string suffix : {"/lib", "/lib64"}) { + if (EndWith(cuda_include_path, suffix)) { + cuda_include_path.erase(cuda_include_path.length() - suffix.length()); + break; + } + } + // Any lib suffix is gone now; append "/include" unless already present. + if (!EndWith(cuda_include_path, "include")) { + cuda_include_path += "/include"; + } + // Use the flag-derived directory only if it exists on the file system.
+ if (stat(cuda_include_path.c_str(), &st) == 0) { + return cuda_include_path; + } + } + + cuda_include_path = "/usr/local/cuda/include"; + if (stat(cuda_include_path.c_str(), &st) == 0) { + return cuda_include_path; + } + LOG(WARNING) << "Cannot find CUDA include path. " + << "Please check whether CUDA is installed in the default " + "installation path, or specify it by export " + "FLAGS_cuda_dir=xxx."; + return ""; +} + CUDADeviceCode::CUDADeviceCode(const Place& place, const std::string& name, const std::string& kernel) { if (!is_gpu_place(place)) { @@ -91,7 +134,7 @@ CUDADeviceCode::CUDADeviceCode(const Place& place, const std::string& name, kernel_ = kernel; } -bool CUDADeviceCode::Compile() { +bool CUDADeviceCode::Compile(bool include_path) { is_compiled_ = false; if (!dynload::HasNVRTC() || !dynload::HasCUDADriver()) { LOG(WARNING) @@ -116,8 +159,14 @@ bool CUDADeviceCode::Compile() { int compute_capability = dev_ctx->GetComputeCapability(); std::string compute_flag = "--gpu-architecture=compute_" + std::to_string(compute_capability); - const std::vector options = {"--std=c++11", - compute_flag.c_str()}; + std::vector options = {"--std=c++11", compute_flag.c_str()}; + // include_option is declared at function scope because options keeps a raw + // pointer into it that nvrtcCompileProgram reads below. + std::string include_option = include_path ? FindCUDAIncludePath() : ""; + if (!include_option.empty()) { + include_option.insert(0, "--include-path="); + options.push_back(include_option.c_str()); + } nvrtcResult compile_result = dynload::nvrtcCompileProgram(program, // program options.size(), // numOptions diff --git a/paddle/fluid/platform/device_code.h b/paddle/fluid/platform/device_code.h index 2895c568b6e..38520754406 100644 --- a/paddle/fluid/platform/device_code.h +++ b/paddle/fluid/platform/device_code.h @@ -31,7 +31,7 @@ namespace platform { class DeviceCode { public: virtual ~DeviceCode() {} - virtual bool Compile() = 0; + virtual bool Compile(bool include_path = false) = 0; virtual void Launch(const size_t n, std::vector* args) const = 0; Place
GetPlace() const { return place_; } @@ -48,7 +48,7 @@ class CUDADeviceCode : public DeviceCode { public: explicit CUDADeviceCode(const Place& place, const std::string& name, const std::string& kernel); - bool Compile() override; + bool Compile(bool include_path = false) override; void Launch(const size_t n, std::vector* args) const override; void SetNumThreads(int num_threads) { num_threads_ = num_threads; } diff --git a/python/paddle/fluid/tests/unittests/ir/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/CMakeLists.txt index dbf6537dc53..9ecddac3a01 100644 --- a/python/paddle/fluid/tests/unittests/ir/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/CMakeLists.txt @@ -2,7 +2,7 @@ file(GLOB TEST_IR_PASSES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") string(REPLACE ".py" "" TEST_IR_PASSES "${TEST_IR_PASSES}") if(NOT WITH_GPU OR WIN32 OR APPLE) - LIST(REMOVE_ITEM TEST_IR_PASSES test_ir_fusion_group) + LIST(REMOVE_ITEM TEST_IR_PASSES test_ir_fusion_group_pass) endif() foreach(target ${TEST_IR_PASSES}) diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group.py b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group.py deleted file mode 100644 index 99181cbdc02..00000000000 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -import numpy as np -from pass_test import PassTest -import paddle.fluid as fluid -import paddle.fluid.core as core - - -class FusionGroupPassTest(PassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data1 = fluid.data(name="data1", shape=[32, 128], dtype="float32") - data2 = fluid.data(name="data2", shape=[32, 128], dtype="float32") - data3 = fluid.data(name="data3", shape=[32, 128], dtype="float32") - tmp_1 = fluid.layers.elementwise_add(data1, data2) - tmp_2 = fluid.layers.elementwise_mul(data3, tmp_1) - - self.feeds = { - "data1": np.random.random((32, 128)).astype("float32"), - "data2": np.random.random((32, 128)).astype("float32"), - "data3": np.random.random((32, 128)).astype("float32") - } - self.fetch_list = [tmp_1, tmp_2] - self.pass_names = "fusion_group_pass" - self.fused_op_type = "fusion_group" - self.num_fused_ops = 1 - - def test_check_output(self): - use_gpu_set = [] - if core.is_compiled_with_cuda(): - use_gpu_set.append(True) - for use_gpu in use_gpu_set: - self.pass_attrs = {"fusion_group_pass": {"use_gpu": use_gpu}} - place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() - self.check_output_with_place(place, startup_on_cpu=False) - - -class FusionGroupPassTest1(FusionGroupPassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = [] - for i in range(5): - data.append( - fluid.data( - name=("data" + str(i)), - shape=[32, 128], - dtype="float32")) - tmp_1 = ( - fluid.layers.assign(data[0]) * fluid.layers.sigmoid(data[1]) - ) + (fluid.layers.sigmoid(data[2]) * fluid.layers.tanh(data[3])) - tmp_2 = fluid.layers.tanh(tmp_1) + fluid.layers.sigmoid(data[4]) - - self.feeds = {} - for i in range(5): - self.feeds["data" + str(i)] = np.random.random( - (32, 128)).astype("float32") - - self.fetch_list = [tmp_1, tmp_2] - self.pass_names = "fusion_group_pass" - self.fused_op_type = "fusion_group" - self.num_fused_ops = 1 - - -class 
FusionGroupPassTest2(FusionGroupPassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = [] - for i in range(3): - data.append( - fluid.data( - name=("data" + str(i)), - shape=[32, 128], - dtype="float32")) - data.append( - fluid.data( - name="data3", shape=[128, 32], dtype="float32")) - tmp_1 = fluid.layers.relu((data[0] - data[1]) * data[2]) - tmp_2 = fluid.layers.sigmoid(data[3]) - tmp_3 = fluid.layers.relu(tmp_2) - tmp_4 = fluid.layers.mul(tmp_1, tmp_3) - - self.feeds = {} - for i in range(3): - self.feeds["data" + str(i)] = np.random.random( - (32, 128)).astype("float32") - self.feeds["data3"] = np.random.random((128, 32)).astype("float32") - - self.fetch_list = [tmp_1, tmp_2, tmp_3, tmp_4] - self.pass_names = "fusion_group_pass" - self.fused_op_type = "fusion_group" - self.num_fused_ops = 2 - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py new file mode 100644 index 00000000000..e0121e08eff --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_fusion_group_pass.py @@ -0,0 +1,142 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from pass_test import PassTest +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core + + +class FusionGroupPassTest(PassTest): + def build_program(self, dtype): + with fluid.program_guard(self.main_program, self.startup_program): + self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 2) + self.feed_vars.append( + fluid.data( + name="data2", shape=[128, 128], dtype=dtype)) + + # subgraph with only 1 op node: the product tmp_0 + tmp_0 = self.feed_vars[0] * self.feed_vars[1] + tmp_1 = layers.mul(tmp_0, self.feed_vars[2]) + # subgraph with 2 op nodes: the add and the relu + tmp_2 = layers.relu(tmp_0 + tmp_1) + + self.fetch_list = [tmp_2] + self.num_fused_ops = 1 + + def setUp(self): + self.build_program("float32") + self.feeds = self._feed_random_data(self.feed_vars) + self.pass_names = "fusion_group_pass" + self.fused_op_type = "fusion_group" + + def _prepare_feed_vars(self, shape, dtype, num_data): + feed_vars = [] + for i in range(num_data): + var = fluid.data(name=("data" + str(i)), shape=shape, dtype=dtype) + feed_vars.append(var) + return feed_vars + + def _feed_random_data(self, feed_vars): + feeds = {} + for var in feed_vars: + if var.type != fluid.core.VarDesc.VarType.LOD_TENSOR: + raise TypeError("Feed data of non LoDTensor is not supported.") + + shape = var.shape + if var.dtype == fluid.core.VarDesc.VarType.FP32: + dtype = "float32" + elif var.dtype == fluid.core.VarDesc.VarType.FP64: + dtype = "float64" + elif var.dtype == fluid.core.VarDesc.VarType.FP16: + dtype = "float16" + else: + raise ValueError("Unsupported dtype %s" % var.dtype) + feeds[var.name] = np.random.random(shape).astype(dtype) + return feeds + + def test_check_output(self): + if core.is_compiled_with_cuda(): + self.pass_attrs = {"fusion_group_pass": {"use_gpu": True}} + self.check_output_with_place(fluid.CUDAPlace(0)) + + +class FusionGroupPassTest1(FusionGroupPassTest): + def build_program(self, dtype): + with
fluid.program_guard(self.main_program, self.startup_program): + self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 5) + + tmp_0 = layers.assign(self.feed_vars[0]) + # subgraph with 9 op nodes: 6 for tmp_1, 3 for tmp_2 + tmp_1 = tmp_0 * layers.sigmoid(self.feed_vars[1]) + layers.sigmoid( + self.feed_vars[2]) * layers.tanh(self.feed_vars[3]) + tmp_2 = layers.tanh(tmp_1) + layers.sigmoid(self.feed_vars[4]) + + self.fetch_list = [tmp_1, tmp_2] + self.num_fused_ops = 1 + + +class FusionGroupPassTest2(FusionGroupPassTest): + def build_program(self, dtype): + with fluid.program_guard(self.main_program, self.startup_program): + self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 3) + self.feed_vars.append( + fluid.data( + name="data3", shape=[128, 32], dtype=dtype)) + + # subgraph with 3 op nodes: subtract, multiply, relu + tmp_1 = layers.relu( + (self.feed_vars[0] - self.feed_vars[1]) * self.feed_vars[2]) + # subgraph with 2 op nodes: sigmoid, relu + tmp_2 = layers.relu(layers.sigmoid(self.feed_vars[3])) + tmp_3 = layers.mul(tmp_1, tmp_2) + + self.fetch_list = [tmp_1, tmp_2, tmp_3] + self.num_fused_ops = 2 + + +class FusionGroupPassTestFP64(FusionGroupPassTest): + def setUp(self): + self.build_program("float64") + self.feeds = self._feed_random_data(self.feed_vars) + self.pass_names = "fusion_group_pass" + self.fused_op_type = "fusion_group" + + +class FusionGroupPassTestFP16(FusionGroupPassTest): + def build_program(self, dtype): + with fluid.program_guard(self.main_program, self.startup_program): + self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 2) + self.feed_vars.append( + fluid.data( + name="data2", shape=[128, 128], dtype=dtype)) + + # subgraph with only 1 op node; inputs use the dtype setUp passes in + tmp_0 = self.feed_vars[0] * self.feed_vars[1] + tmp_1 = layers.mul(tmp_0, self.feed_vars[2]) + tmp_2 = layers.cast(tmp_0, dtype="float16") + tmp_3 = layers.cast(tmp_1, dtype="float16") + # subgraph with 2 op nodes, fed by the float16 casts above + tmp_4 = layers.relu(tmp_2 + tmp_3) + tmp_5 = layers.cast(tmp_4, dtype=dtype) + + self.fetch_list = [tmp_5] + self.num_fused_ops =
1 + + +if __name__ == "__main__": + unittest.main() -- GitLab