diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 9476256b0f0e5ac2290a814e73374fb1552ff5c2..88acc28470290cf1b513403eee3a93c92db2af0c 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -110,7 +110,9 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") cc_library(pass_builder SRCS pass_builder.cc DEPS pass) -cc_test(codegen_test SRCS codegen_test.cc DEPS codegen_helper codegen) +if(WITH_GPU) + cc_test(codegen_test SRCS codegen_test.cc DEPS codegen_helper codegen device_code lod_tensor) +endif() cc_test(node_test SRCS node_test.cc DEPS node) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) diff --git a/paddle/fluid/framework/ir/codegen.cc b/paddle/fluid/framework/ir/codegen.cc index c3e5efccba570192453d4336ea36a9a550e5be4d..60a5ff224a943d735b005159bc9d77d253359ffa 100644 --- a/paddle/fluid/framework/ir/codegen.cc +++ b/paddle/fluid/framework/ir/codegen.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,76 +19,15 @@ namespace paddle { namespace framework { namespace ir { -// we get the parameter list code for the expression information -std::string CodeGen::GetDeclarationCode( - std::vector expression) { - std::stringstream ret; - ret << "fuse_kernel"; - ret << R"((int N )"; - std::set input_ids; - std::set output_ids; - std::vector last_output_idis; - - for (size_t i = 0; i < expression.size(); i++) { - std::vector tmp_input = expression[i].GetInputIds(); - for (size_t j = 0; j < tmp_input.size(); j++) { - int id = tmp_input[j]; - input_ids.insert(id); - } - int tmp_output = expression[i].GetOutputId(); - output_ids.insert(tmp_output); - } - - std::set::iterator it = input_ids.begin(); - while (it != input_ids.end()) { - int var_index = *it; - if (output_ids.find(var_index) != output_ids.end()) { - input_ids.erase(it++); - } else { - it++; - } - } - - for (it = input_ids.begin(); it != input_ids.end(); it++) { - int var_index = *it; - ret << R"(, const T* var)" << var_index; - } - - for (it = output_ids.begin(); it != output_ids.end(); it++) { - int var_index = *it; - ret << R"(, T* var)" << var_index; - } - - ret << R"())"; - - return ret.str(); -} - -std::string CodeGen::GetOffsetCode() { - std::stringstream ret; - ret << indentation << "int offset = idx;" << std::endl; - return ret.str(); +CodeGenerator::CodeGenerator(CodeTemplate code_template) { + code_template_ = code_template; } -std::string CodeGen::GetComputeCode( - std::vector expression) { - // get the right experssion code using suffix expression - std::stringstream ret; - for (size_t i = 0; i < expression.size(); i++) { - ret << expression[i].GetExpression(); - } - return ret.str(); -} // in order to get the right result of expression, we need to calculate, we // store the expression as // suffix Expressions using vector -std::string CodeGen::GetKernelCode( - std::vector expression) { - auto declaration_code = GetDeclarationCode(expression); - auto offset_code = GetOffsetCode(); - auto compute_code = GetComputeCode(expression); - auto cuda_kernel = const_kernel_start + declaration_code + const_kernel_mid + - offset_code + compute_code + const_kernel_end; +std::string CodeGenerator::GenerateCode(TemplateVariable template_var) { + auto cuda_kernel = kernel_function + code_template_.Format(template_var); return cuda_kernel; } } // namespace ir diff --git a/paddle/fluid/framework/ir/codegen.h b/paddle/fluid/framework/ir/codegen.h index 975d48885e72a3b6f6aa5cf89fa943118593834e..2cf61ada48e727c4c8bbb0643997950ea0c5ada0 100644 --- a/paddle/fluid/framework/ir/codegen.h +++ b/paddle/fluid/framework/ir/codegen.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,16 +20,14 @@ namespace paddle { namespace framework { namespace ir { -class CodeGen { +class CodeGenerator { public: - std::string GetKernelCode(std::vector expression); + explicit CodeGenerator(CodeTemplate code_template); + std::string GenerateCode(TemplateVariable template_var); + // TODO(wangchao66) std::string GenerateCode(const Graph& graph) private: - std::string GetDeclarationCode( - std::vector expression); - std::string GetOffsetCode(); - std::string GetComputeCode( - std::vector expression); + CodeTemplate code_template_; }; } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/codegen_helper.cc b/paddle/fluid/framework/ir/codegen_helper.cc index 8f14549eb717835063bba66503c269729ca2773d..5e0b14253c1d125823278c25947cec62be39d521 100644 --- a/paddle/fluid/framework/ir/codegen_helper.cc +++ b/paddle/fluid/framework/ir/codegen_helper.cc @@ -21,41 +21,46 @@ namespace framework { namespace ir { OperationExpression::OperationExpression(std::vector input_ids, - int output_id, - std::string search_operation) { + int output_id, std::string op) { input_ids_ = input_ids; output_id_ = output_id; - search_operation_ = search_operation; + op_ = op; } +std::string OperationExpression::GetRHSTemplate() { + std::stringstream ret; + std::string rhs_end = ";"; + auto rhs = support_table[op_]; + for (size_t i = 0; i < input_ids_.size(); i++) { + auto replaced_str = replaced_element_in_order[i]; + auto pos = rhs.find(replaced_str); + auto index = input_ids_[i]; + rhs.replace(pos, replaced_str.length(), std::to_string(index) + R"([idx])"); + } + ret << rhs << rhs_end; + return ret.str(); +} + +std::string OperationExpression::GetLHSTemplate() { + std::stringstream ret; + ret << "var" << output_id_ << R"([idx] = )"; + return ret.str(); +} + +bool OperationExpression::SupportState() { + return (support_table.find(op_) == support_table.end()); +} // we Traverse the graph and get the group , all input id and output id is // unique for the node which belong the group std::string OperationExpression::GetExpression() { std::stringstream ret; - if (operator_cuda_table.find(search_operation_) == - operator_cuda_table.end()) { - std::cerr << "Not supportted operation, " << search_operation_ << std::endl; - } else { - auto rhs = operator_cuda_table[search_operation_]; - std::string replaced_str = "$"; - int count = 0; - auto pos = rhs.find(replaced_str); - while (pos != -1) { - auto index = input_ids_[count]; - rhs.replace(pos, replaced_str.length(), - std::to_string(index) + R"([offset])"); - pos = rhs.find(replaced_str); - count++; - } - auto lhs = std::string(indentation) + "var" + std::to_string(output_id_) + - R"([offset])"; - auto equal_split = R"( = )"; - auto semicolon = R"(;)"; - ret << lhs << equal_split << rhs << semicolon << std::endl; + if (!SupportState()) { + ret << GetLHSTemplate() << GetRHSTemplate(); } return ret.str(); } + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/codegen_helper.h b/paddle/fluid/framework/ir/codegen_helper.h index be8d3c8ac26fcde9e8964475709d604822c70688..fbc59c4349042428608e98e94a8ffa698f2baabc 100644 --- a/paddle/fluid/framework/ir/codegen_helper.h +++ b/paddle/fluid/framework/ir/codegen_helper.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,6 +14,8 @@ limitations under the License. */ #pragma once #include +#include +#include #include #include #include @@ -21,50 +23,218 @@ limitations under the License. */ namespace paddle { namespace framework { namespace ir { -static std::unordered_map operator_cuda_table = { - {"elementwise_add", "var$ + var$"}, - {"elementwise_sub", "var$ - var$"}, - {"elementwise_mul", "var$ * var$"}, - {"elementwise_div", "var$ / var$"}, - {"elementwise_min", "real_min(var$, var$)"}, - {"elementwise_max", "real_max(var$, var$)"}, - {"relu", "real_max(var$, 0)"}, - {"sigmoid", "1.0 / (1.0 + real_exp(-var$))"}}; +static std::vector replaced_element_in_order = {"@", "$"}; + +static std::vector kernel_template = {"$name", "$parameter", + "$compute"}; + +static std::unordered_map support_table = { + {"elementwise_add", "var@ + var$"}, + {"elementwise_sub", "var@ - var$"}, + {"elementwise_mul", "var@ * var$"}, + {"elementwise_div", "var@ / var$"}, + {"elementwise_min", "real_min(var@, var$)"}, + {"elementwise_max", "real_max(var@, var$)"}, + {"relu", "real_max(var@, 0)"}, + {"sigmoid", "1.0 / (1.0 + real_exp(-var@))"}}; + +// Paddle elementwise op consist the broacast op and elementwise op // op computation is composed by single or many operation +// here we only generate the simple expression code so we +// make it simple class OperationExpression { public: OperationExpression(std::vector input_ids, int output_id, - std::string search_oprtation); + std::string op); std::string GetExpression(); std::vector GetInputIds() { return input_ids_; } int GetOutputId() { return output_id_; } + bool SupportState(); + // in oreder to make offset more flexible we add stride and basic offset + std::string GetRHSTemplate(); + std::string GetLHSTemplate(); private: std::vector input_ids_; int output_id_; - std::string search_operation_; + std::string op_; }; -static const char indentation[] = R"( )"; +class TemplateVariable { + public: + void Add(std::string identifier, std::string expression) { + strings_[identifier] = expression; + } + void Remove(std::string identifier, std::string expression) { + for (auto it = strings_.begin(); it != strings_.end();) { + if (it->first == identifier) { + it = strings_.erase(it); + } else { + it++; + } + } + } + + std::unordered_map Get() { return strings_; } + + private: + std::unordered_map strings_; +}; +class CodeTemplate { + public: + CodeTemplate() = default; + explicit CodeTemplate(std::string template_str) { + template_str_ = template_str; + } + + std::string Format(TemplateVariable template_var) { + std::string ret = template_str_; + std::unordered_map identifier_str = + template_var.Get(); + + for (size_t i = 0; i < ret.size(); i++) { + auto pos = i; + char c = ret[pos]; + + if (c == '$') { + for (size_t j = 0; j < kernel_template.size(); j++) { + int template_size = kernel_template[j].size(); + auto tmp_cmp = ret.substr(pos, template_size); + if (tmp_cmp == kernel_template[j]) { + ret.replace(pos, template_size, identifier_str[kernel_template[j]]); + } + } + } + } + + return EmitIndents(ret); + } + std::string EmitIndents(std::string str) { + std::string ret = str; + int space_num = 0; + auto space_char = ' '; + for (size_t i = 0; i < ret.size(); i++) { + auto pos = i; + char c = ret[pos]; + if (c == '\n') { + size_t next_pos = pos + 1; + while (next_pos < ret.size() && ret[next_pos] == space_char) { + next_pos++; + } + space_num = next_pos - pos - 1; + } + if (c == ';' && (pos + 1 < ret.size()) && ret[pos + 1] != '\n') { + auto insert_pos = pos + 1; + std::string insert_str = "\n" + std::string(space_num, space_char); + ret.insert(insert_pos, insert_str); + space_num = 0; + } + } + + return ret; + } + + private: + std::string template_str_; +}; + +static std::string EmitUniqueName(std::vector expression) { + std::stringstream ret; + ret << "fused_kernel"; + for (size_t i = 0; i < expression.size(); i++) { + ret << expression[i].GetOutputId(); + } + return ret.str(); +} +// we get the parameter list code for the expression information +static std::string EmitDeclarationCode( + std::vector expression, std::string type) { + std::stringstream ret; + + std::set input_ids; + std::set output_ids; + + for (size_t i = 0; i < expression.size(); i++) { + std::vector tmp_input = expression[i].GetInputIds(); + for (size_t j = 0; j < tmp_input.size(); j++) { + int id = tmp_input[j]; + input_ids.insert(id); + } + int tmp_output = expression[i].GetOutputId(); + output_ids.insert(tmp_output); + } + + std::set::iterator it = input_ids.begin(); + while (it != input_ids.end()) { + int var_index = *it; + if (output_ids.find(var_index) != output_ids.end()) { + input_ids.erase(it++); + } else { + it++; + } + } + + ret << "int N, "; + for (it = input_ids.begin(); it != input_ids.end(); it++) { + int var_index = *it; + ret << type << R"(* var)" << var_index; + ret << ", "; + } + + size_t count_index = 0; + for (it = output_ids.begin(); it != output_ids.end(); it++) { + int var_index = *it; + ret << type << R"(* var)" << var_index; + if (count_index != output_ids.size() - 1) { + ret << ", "; + } + count_index++; + } + + return ret.str(); +} + +static std::string EmitComputeCode( + std::vector expression) { + // get the right experssion code using suffix expression + std::stringstream ret; + for (size_t i = 0; i < expression.size(); i++) { + ret << expression[i].GetExpression(); + } + return ret.str(); +} + +static const char kernel_function[] = R"( +__device__ float real_exp(float x) { return ::expf(x); } + +__device__ double real_exp(double x) { return ::exp(x); } + +__device__ float real_log(float x) { return ::logf(x); } + +__device__ double real_log(double x) { return ::log(x); } + +__device__ float real_min(float x, float y) { return ::fminf(x, y); } + +__device__ double real_min(double x, double y) { return ::fmin(x, y); } + +__device__ float real_max(float x, float y) { return ::fmaxf(x, y); } + +__device__ double real_max(double x, double y) { return ::fmax(x, y); } -static const char const_kernel_start[] = R"( -template -extern "C" __global__ void )"; -static const char const_kernel_mid[] = R"( -{ +static const char kernel_elementwise_template[] = R"( + +extern "C" __global__ void $name($parameter){ for(int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N; idx += gridDim.x * blockDim.x) { - -)"; - -static const char const_kernel_end[] = R"( + $compute } } )"; + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/codegen_test.cc b/paddle/fluid/framework/ir/codegen_test.cc index 8fd5fde3df2c1a1876b346f747f9158a3d40499b..7877b218484011e2aa49d42762ce47d0c5895aac 100644 --- a/paddle/fluid/framework/ir/codegen_test.cc +++ b/paddle/fluid/framework/ir/codegen_test.cc @@ -1,43 +1,140 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #include "paddle/fluid/framework/ir/codegen.h" #include +#include #include #include #include "paddle/fluid/framework/ir/codegen_helper.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/operators/math.h" +#include "paddle/fluid/platform/device_code.h" +#include "paddle/fluid/platform/init.h" #ifdef PADDLE_WITH_CUDA + TEST(codegen, cuda) { std::vector mul_input{1, 2}; std::vector add_input{3, 4}; - std::vector sigmod_input{5}; + std::vector sub_input{5, 6}; + std::vector relu_input{7}; + std::vector sigmoid_input{8}; + int mul_out = 3; int add_out = 5; - int sigmod_out = 6; + int sub_out = 7; + int relu_out = 8; + int sigmoid_out = 9; std::string op1 = "elementwise_mul"; std::string op2 = "elementwise_add"; - std::string op3 = "sigmoid"; + std::string op3 = "elementwise_sub"; + std::string op4 = "relu"; + std::string op5 = "sigmoid"; paddle::framework::ir::OperationExpression opexp1(mul_input, mul_out, op1); paddle::framework::ir::OperationExpression opexp2(add_input, add_out, op2); - paddle::framework::ir::OperationExpression opexp3(sigmod_input, sigmod_out, - op3); + paddle::framework::ir::OperationExpression opexp3(sub_input, sub_out, op3); + paddle::framework::ir::OperationExpression opexp4(relu_input, relu_out, op4); + paddle::framework::ir::OperationExpression opexp5(sigmoid_input, sigmoid_out, + op5); std::vector fused_op = { - opexp1, opexp2, opexp3}; - paddle::framework::ir::CodeGen codegen; - std::string result = codegen.GetKernelCode(fused_op); - std::cout << result << std::endl; + opexp1, opexp2, opexp3, opexp4, opexp5}; + paddle::framework::ir::CodeTemplate code_template( + paddle::framework::ir::kernel_elementwise_template); + paddle::framework::ir::CodeGenerator codegen(code_template); + paddle::framework::ir::TemplateVariable template_var; + template_var.Add("$name", EmitUniqueName(fused_op)); + template_var.Add("$parameter", EmitDeclarationCode(fused_op, "float")); + template_var.Add("$compute", EmitComputeCode(fused_op)); + std::string saxpy_code = codegen.GenerateCode(template_var); + + std::cout << saxpy_code << std::endl; + paddle::framework::InitDevices(false, {0}); + paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0); + paddle::platform::CUDADeviceCode code(place, EmitUniqueName(fused_op), + saxpy_code); + + paddle::framework::Tensor cpu_a; + paddle::framework::Tensor cpu_b; + paddle::framework::Tensor cpu_c; + paddle::framework::Tensor cpu_d; + paddle::framework::Tensor cpu_e; + paddle::framework::Tensor cpu_f; + paddle::framework::Tensor cpu_g; + paddle::framework::Tensor cpu_h; + paddle::framework::Tensor cpu_o; + + auto dims = paddle::framework::make_ddim( + {static_cast(256), static_cast(1024)}); + cpu_a.mutable_data(dims, paddle::platform::CPUPlace()); + cpu_b.mutable_data(dims, paddle::platform::CPUPlace()); + cpu_c.mutable_data(dims, paddle::platform::CPUPlace()); + cpu_d.mutable_data(dims, paddle::platform::CPUPlace()); + cpu_e.mutable_data(dims, paddle::platform::CPUPlace()); + cpu_f.mutable_data(dims, paddle::platform::CPUPlace()); + cpu_g.mutable_data(dims, paddle::platform::CPUPlace()); + cpu_o.mutable_data(dims, paddle::platform::CPUPlace()); + + size_t n = cpu_a.numel(); + for (size_t i = 0; i < n; ++i) { + cpu_a.data()[i] = static_cast(i); + } + for (size_t i = 0; i < n; ++i) { + cpu_b.data()[i] = static_cast(0.5); + cpu_d.data()[i] = static_cast(10.0); + cpu_f.data()[i] = static_cast(0.0); + } + + paddle::framework::Tensor a; + paddle::framework::Tensor b; + paddle::framework::Tensor c; + paddle::framework::Tensor d; + paddle::framework::Tensor e; + paddle::framework::Tensor f; + paddle::framework::Tensor g; + paddle::framework::Tensor h; + paddle::framework::Tensor o; + + float* a_data = a.mutable_data(dims, place); + float* b_data = b.mutable_data(dims, place); + float* c_data = c.mutable_data(dims, place); + float* d_data = d.mutable_data(dims, place); + float* e_data = e.mutable_data(dims, place); + float* f_data = f.mutable_data(dims, place); + float* g_data = g.mutable_data(dims, place); + float* h_data = h.mutable_data(dims, place); + float* o_data = o.mutable_data(dims, place); + + TensorCopySync(cpu_a, place, &a); + TensorCopySync(cpu_b, place, &b); + TensorCopySync(cpu_d, place, &d); + TensorCopySync(cpu_f, place, &f); + + code.Compile(); + + std::vector args = {&n, &a_data, &b_data, &d_data, &f_data, + &c_data, &e_data, &g_data, &h_data, &o_data}; + code.SetNumThreads(1024); + code.SetWorkloadPerThread(1); + code.Launch(n, &args); + + TensorCopySync(o, paddle::platform::CPUPlace(), &cpu_o); + for (size_t i = 0; i < n; i++) { + float result = + (1.0 / (1.0 + std::exp(-std::max( + 0.0, static_cast(i) * 0.5 + 10.0 - 0.0)))); + PADDLE_ENFORCE_EQ(cpu_o.data()[i], result); + } } #endif