未验证 提交 c9ea317b 编写于 作者: W wangchaochaohu 提交者: GitHub

codegen code for reconstruction (#19728)

* codegen code for reconstruction test=develop

* fix the cmake test=develop

* fix review advice test=develop
上级 647ff784
...@@ -110,7 +110,9 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") ...@@ -110,7 +110,9 @@ set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
cc_library(pass_builder SRCS pass_builder.cc DEPS pass) cc_library(pass_builder SRCS pass_builder.cc DEPS pass)
cc_test(codegen_test SRCS codegen_test.cc DEPS codegen_helper codegen) if(WITH_GPU)
cc_test(codegen_test SRCS codegen_test.cc DEPS codegen_helper codegen device_code lod_tensor)
endif()
cc_test(node_test SRCS node_test.cc DEPS node) cc_test(node_test SRCS node_test.cc DEPS node)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -19,76 +19,15 @@ namespace paddle { ...@@ -19,76 +19,15 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// we get the parameter list code for the expression information CodeGenerator::CodeGenerator(CodeTemplate code_template) {
std::string CodeGen::GetDeclarationCode( code_template_ = code_template;
std::vector<OperationExpression> expression) {
std::stringstream ret;
ret << "fuse_kernel";
ret << R"((int N )";
std::set<int> input_ids;
std::set<int> output_ids;
std::vector<int> last_output_idis;
for (size_t i = 0; i < expression.size(); i++) {
std::vector<int> tmp_input = expression[i].GetInputIds();
for (size_t j = 0; j < tmp_input.size(); j++) {
int id = tmp_input[j];
input_ids.insert(id);
}
int tmp_output = expression[i].GetOutputId();
output_ids.insert(tmp_output);
}
std::set<int>::iterator it = input_ids.begin();
while (it != input_ids.end()) {
int var_index = *it;
if (output_ids.find(var_index) != output_ids.end()) {
input_ids.erase(it++);
} else {
it++;
}
}
for (it = input_ids.begin(); it != input_ids.end(); it++) {
int var_index = *it;
ret << R"(, const T* var)" << var_index;
}
for (it = output_ids.begin(); it != output_ids.end(); it++) {
int var_index = *it;
ret << R"(, T* var)" << var_index;
}
ret << R"())";
return ret.str();
}
std::string CodeGen::GetOffsetCode() {
std::stringstream ret;
ret << indentation << "int offset = idx;" << std::endl;
return ret.str();
} }
std::string CodeGen::GetComputeCode(
std::vector<OperationExpression> expression) {
// get the right experssion code using suffix expression
std::stringstream ret;
for (size_t i = 0; i < expression.size(); i++) {
ret << expression[i].GetExpression();
}
return ret.str();
}
// in order to get the right result of expression, we need to calculate, we // in order to get the right result of expression, we need to calculate, we
// store the expression as // store the expression as
// suffix Expressions using vector // suffix Expressions using vector
std::string CodeGen::GetKernelCode( std::string CodeGenerator::GenerateCode(TemplateVariable template_var) {
std::vector<OperationExpression> expression) { auto cuda_kernel = kernel_function + code_template_.Format(template_var);
auto declaration_code = GetDeclarationCode(expression);
auto offset_code = GetOffsetCode();
auto compute_code = GetComputeCode(expression);
auto cuda_kernel = const_kernel_start + declaration_code + const_kernel_mid +
offset_code + compute_code + const_kernel_end;
return cuda_kernel; return cuda_kernel;
} }
} // namespace ir } // namespace ir
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -20,16 +20,14 @@ namespace paddle { ...@@ -20,16 +20,14 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class CodeGen { class CodeGenerator {
public: public:
std::string GetKernelCode(std::vector<OperationExpression> expression); explicit CodeGenerator(CodeTemplate code_template);
std::string GenerateCode(TemplateVariable template_var);
// TODO(wangchao66) std::string GenerateCode(const Graph& graph)
private: private:
std::string GetDeclarationCode( CodeTemplate code_template_;
std::vector<paddle::framework::ir::OperationExpression> expression);
std::string GetOffsetCode();
std::string GetComputeCode(
std::vector<paddle::framework::ir::OperationExpression> expression);
}; };
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -21,41 +21,46 @@ namespace framework { ...@@ -21,41 +21,46 @@ namespace framework {
namespace ir { namespace ir {
OperationExpression::OperationExpression(std::vector<int> input_ids, OperationExpression::OperationExpression(std::vector<int> input_ids,
int output_id, int output_id, std::string op) {
std::string search_operation) {
input_ids_ = input_ids; input_ids_ = input_ids;
output_id_ = output_id; output_id_ = output_id;
search_operation_ = search_operation; op_ = op;
} }
std::string OperationExpression::GetRHSTemplate() {
std::stringstream ret;
std::string rhs_end = ";";
auto rhs = support_table[op_];
for (size_t i = 0; i < input_ids_.size(); i++) {
auto replaced_str = replaced_element_in_order[i];
auto pos = rhs.find(replaced_str);
auto index = input_ids_[i];
rhs.replace(pos, replaced_str.length(), std::to_string(index) + R"([idx])");
}
ret << rhs << rhs_end;
return ret.str();
}
std::string OperationExpression::GetLHSTemplate() {
std::stringstream ret;
ret << "var" << output_id_ << R"([idx] = )";
return ret.str();
}
bool OperationExpression::SupportState() {
return (support_table.find(op_) == support_table.end());
}
// we Traverse the graph and get the group , all input id and output id is // we Traverse the graph and get the group , all input id and output id is
// unique for the node which belong the group // unique for the node which belong the group
std::string OperationExpression::GetExpression() { std::string OperationExpression::GetExpression() {
std::stringstream ret; std::stringstream ret;
if (operator_cuda_table.find(search_operation_) == if (!SupportState()) {
operator_cuda_table.end()) { ret << GetLHSTemplate() << GetRHSTemplate();
std::cerr << "Not supportted operation, " << search_operation_ << std::endl;
} else {
auto rhs = operator_cuda_table[search_operation_];
std::string replaced_str = "$";
int count = 0;
auto pos = rhs.find(replaced_str);
while (pos != -1) {
auto index = input_ids_[count];
rhs.replace(pos, replaced_str.length(),
std::to_string(index) + R"([offset])");
pos = rhs.find(replaced_str);
count++;
}
auto lhs = std::string(indentation) + "var" + std::to_string(output_id_) +
R"([offset])";
auto equal_split = R"( = )";
auto semicolon = R"(;)";
ret << lhs << equal_split << rhs << semicolon << std::endl;
} }
return ret.str(); return ret.str();
} }
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <set>
#include <sstream>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -21,50 +23,218 @@ limitations under the License. */ ...@@ -21,50 +23,218 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static std::unordered_map<std::string, std::string> operator_cuda_table = {
{"elementwise_add", "var$ + var$"},
{"elementwise_sub", "var$ - var$"},
{"elementwise_mul", "var$ * var$"},
{"elementwise_div", "var$ / var$"},
{"elementwise_min", "real_min(var$, var$)"},
{"elementwise_max", "real_max(var$, var$)"},
{"relu", "real_max(var$, 0)"},
{"sigmoid", "1.0 / (1.0 + real_exp(-var$))"}};
static std::vector<std::string> replaced_element_in_order = {"@", "$"};
static std::vector<std::string> kernel_template = {"$name", "$parameter",
"$compute"};
static std::unordered_map<std::string, std::string> support_table = {
{"elementwise_add", "var@ + var$"},
{"elementwise_sub", "var@ - var$"},
{"elementwise_mul", "var@ * var$"},
{"elementwise_div", "var@ / var$"},
{"elementwise_min", "real_min(var@, var$)"},
{"elementwise_max", "real_max(var@, var$)"},
{"relu", "real_max(var@, 0)"},
{"sigmoid", "1.0 / (1.0 + real_exp(-var@))"}};
// Paddle elementwise op consist the broacast op and elementwise op
// op computation is composed by single or many operation // op computation is composed by single or many operation
// here we only generate the simple expression code so we
// make it simple
class OperationExpression { class OperationExpression {
public: public:
OperationExpression(std::vector<int> input_ids, int output_id, OperationExpression(std::vector<int> input_ids, int output_id,
std::string search_oprtation); std::string op);
std::string GetExpression(); std::string GetExpression();
std::vector<int> GetInputIds() { return input_ids_; } std::vector<int> GetInputIds() { return input_ids_; }
int GetOutputId() { return output_id_; } int GetOutputId() { return output_id_; }
bool SupportState();
// in oreder to make offset more flexible we add stride and basic offset
std::string GetRHSTemplate();
std::string GetLHSTemplate();
private: private:
std::vector<int> input_ids_; std::vector<int> input_ids_;
int output_id_; int output_id_;
std::string search_operation_; std::string op_;
}; };
static const char indentation[] = R"( )"; class TemplateVariable {
public:
void Add(std::string identifier, std::string expression) {
strings_[identifier] = expression;
}
void Remove(std::string identifier, std::string expression) {
for (auto it = strings_.begin(); it != strings_.end();) {
if (it->first == identifier) {
it = strings_.erase(it);
} else {
it++;
}
}
}
std::unordered_map<std::string, std::string> Get() { return strings_; }
private:
std::unordered_map<std::string, std::string> strings_;
};
class CodeTemplate {
public:
CodeTemplate() = default;
explicit CodeTemplate(std::string template_str) {
template_str_ = template_str;
}
std::string Format(TemplateVariable template_var) {
std::string ret = template_str_;
std::unordered_map<std::string, std::string> identifier_str =
template_var.Get();
for (size_t i = 0; i < ret.size(); i++) {
auto pos = i;
char c = ret[pos];
if (c == '$') {
for (size_t j = 0; j < kernel_template.size(); j++) {
int template_size = kernel_template[j].size();
auto tmp_cmp = ret.substr(pos, template_size);
if (tmp_cmp == kernel_template[j]) {
ret.replace(pos, template_size, identifier_str[kernel_template[j]]);
}
}
}
}
return EmitIndents(ret);
}
std::string EmitIndents(std::string str) {
std::string ret = str;
int space_num = 0;
auto space_char = ' ';
for (size_t i = 0; i < ret.size(); i++) {
auto pos = i;
char c = ret[pos];
if (c == '\n') {
size_t next_pos = pos + 1;
while (next_pos < ret.size() && ret[next_pos] == space_char) {
next_pos++;
}
space_num = next_pos - pos - 1;
}
if (c == ';' && (pos + 1 < ret.size()) && ret[pos + 1] != '\n') {
auto insert_pos = pos + 1;
std::string insert_str = "\n" + std::string(space_num, space_char);
ret.insert(insert_pos, insert_str);
space_num = 0;
}
}
return ret;
}
private:
std::string template_str_;
};
static std::string EmitUniqueName(std::vector<OperationExpression> expression) {
std::stringstream ret;
ret << "fused_kernel";
for (size_t i = 0; i < expression.size(); i++) {
ret << expression[i].GetOutputId();
}
return ret.str();
}
// we get the parameter list code for the expression information
static std::string EmitDeclarationCode(
std::vector<OperationExpression> expression, std::string type) {
std::stringstream ret;
std::set<int> input_ids;
std::set<int> output_ids;
for (size_t i = 0; i < expression.size(); i++) {
std::vector<int> tmp_input = expression[i].GetInputIds();
for (size_t j = 0; j < tmp_input.size(); j++) {
int id = tmp_input[j];
input_ids.insert(id);
}
int tmp_output = expression[i].GetOutputId();
output_ids.insert(tmp_output);
}
std::set<int>::iterator it = input_ids.begin();
while (it != input_ids.end()) {
int var_index = *it;
if (output_ids.find(var_index) != output_ids.end()) {
input_ids.erase(it++);
} else {
it++;
}
}
ret << "int N, ";
for (it = input_ids.begin(); it != input_ids.end(); it++) {
int var_index = *it;
ret << type << R"(* var)" << var_index;
ret << ", ";
}
size_t count_index = 0;
for (it = output_ids.begin(); it != output_ids.end(); it++) {
int var_index = *it;
ret << type << R"(* var)" << var_index;
if (count_index != output_ids.size() - 1) {
ret << ", ";
}
count_index++;
}
return ret.str();
}
static std::string EmitComputeCode(
std::vector<OperationExpression> expression) {
// get the right experssion code using suffix expression
std::stringstream ret;
for (size_t i = 0; i < expression.size(); i++) {
ret << expression[i].GetExpression();
}
return ret.str();
}
static const char kernel_function[] = R"(
__device__ float real_exp(float x) { return ::expf(x); }
__device__ double real_exp(double x) { return ::exp(x); }
__device__ float real_log(float x) { return ::logf(x); }
__device__ double real_log(double x) { return ::log(x); }
__device__ float real_min(float x, float y) { return ::fminf(x, y); }
__device__ double real_min(double x, double y) { return ::fmin(x, y); }
__device__ float real_max(float x, float y) { return ::fmaxf(x, y); }
__device__ double real_max(double x, double y) { return ::fmax(x, y); }
static const char const_kernel_start[] = R"(
template <typename T>
extern "C" __global__ void
)"; )";
static const char const_kernel_mid[] = R"( static const char kernel_elementwise_template[] = R"(
{
extern "C" __global__ void $name($parameter){
for(int idx = blockIdx.x * blockDim.x + threadIdx.x; for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < N; idx < N;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
$compute
)";
static const char const_kernel_end[] = R"(
} }
} }
)"; )";
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/codegen.h" #include "paddle/fluid/framework/ir/codegen.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <cmath>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/codegen_helper.h" #include "paddle/fluid/framework/ir/codegen_helper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/device_code.h"
#include "paddle/fluid/platform/init.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
TEST(codegen, cuda) { TEST(codegen, cuda) {
std::vector<int> mul_input{1, 2}; std::vector<int> mul_input{1, 2};
std::vector<int> add_input{3, 4}; std::vector<int> add_input{3, 4};
std::vector<int> sigmod_input{5}; std::vector<int> sub_input{5, 6};
std::vector<int> relu_input{7};
std::vector<int> sigmoid_input{8};
int mul_out = 3; int mul_out = 3;
int add_out = 5; int add_out = 5;
int sigmod_out = 6; int sub_out = 7;
int relu_out = 8;
int sigmoid_out = 9;
std::string op1 = "elementwise_mul"; std::string op1 = "elementwise_mul";
std::string op2 = "elementwise_add"; std::string op2 = "elementwise_add";
std::string op3 = "sigmoid"; std::string op3 = "elementwise_sub";
std::string op4 = "relu";
std::string op5 = "sigmoid";
paddle::framework::ir::OperationExpression opexp1(mul_input, mul_out, op1); paddle::framework::ir::OperationExpression opexp1(mul_input, mul_out, op1);
paddle::framework::ir::OperationExpression opexp2(add_input, add_out, op2); paddle::framework::ir::OperationExpression opexp2(add_input, add_out, op2);
paddle::framework::ir::OperationExpression opexp3(sigmod_input, sigmod_out, paddle::framework::ir::OperationExpression opexp3(sub_input, sub_out, op3);
op3); paddle::framework::ir::OperationExpression opexp4(relu_input, relu_out, op4);
paddle::framework::ir::OperationExpression opexp5(sigmoid_input, sigmoid_out,
op5);
std::vector<paddle::framework::ir::OperationExpression> fused_op = { std::vector<paddle::framework::ir::OperationExpression> fused_op = {
opexp1, opexp2, opexp3}; opexp1, opexp2, opexp3, opexp4, opexp5};
paddle::framework::ir::CodeGen codegen; paddle::framework::ir::CodeTemplate code_template(
std::string result = codegen.GetKernelCode(fused_op); paddle::framework::ir::kernel_elementwise_template);
std::cout << result << std::endl; paddle::framework::ir::CodeGenerator codegen(code_template);
paddle::framework::ir::TemplateVariable template_var;
template_var.Add("$name", EmitUniqueName(fused_op));
template_var.Add("$parameter", EmitDeclarationCode(fused_op, "float"));
template_var.Add("$compute", EmitComputeCode(fused_op));
std::string saxpy_code = codegen.GenerateCode(template_var);
std::cout << saxpy_code << std::endl;
paddle::framework::InitDevices(false, {0});
paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceCode code(place, EmitUniqueName(fused_op),
saxpy_code);
paddle::framework::Tensor cpu_a;
paddle::framework::Tensor cpu_b;
paddle::framework::Tensor cpu_c;
paddle::framework::Tensor cpu_d;
paddle::framework::Tensor cpu_e;
paddle::framework::Tensor cpu_f;
paddle::framework::Tensor cpu_g;
paddle::framework::Tensor cpu_h;
paddle::framework::Tensor cpu_o;
auto dims = paddle::framework::make_ddim(
{static_cast<int64_t>(256), static_cast<int64_t>(1024)});
cpu_a.mutable_data<float>(dims, paddle::platform::CPUPlace());
cpu_b.mutable_data<float>(dims, paddle::platform::CPUPlace());
cpu_c.mutable_data<float>(dims, paddle::platform::CPUPlace());
cpu_d.mutable_data<float>(dims, paddle::platform::CPUPlace());
cpu_e.mutable_data<float>(dims, paddle::platform::CPUPlace());
cpu_f.mutable_data<float>(dims, paddle::platform::CPUPlace());
cpu_g.mutable_data<float>(dims, paddle::platform::CPUPlace());
cpu_o.mutable_data<float>(dims, paddle::platform::CPUPlace());
size_t n = cpu_a.numel();
for (size_t i = 0; i < n; ++i) {
cpu_a.data<float>()[i] = static_cast<float>(i);
}
for (size_t i = 0; i < n; ++i) {
cpu_b.data<float>()[i] = static_cast<float>(0.5);
cpu_d.data<float>()[i] = static_cast<float>(10.0);
cpu_f.data<float>()[i] = static_cast<float>(0.0);
}
paddle::framework::Tensor a;
paddle::framework::Tensor b;
paddle::framework::Tensor c;
paddle::framework::Tensor d;
paddle::framework::Tensor e;
paddle::framework::Tensor f;
paddle::framework::Tensor g;
paddle::framework::Tensor h;
paddle::framework::Tensor o;
float* a_data = a.mutable_data<float>(dims, place);
float* b_data = b.mutable_data<float>(dims, place);
float* c_data = c.mutable_data<float>(dims, place);
float* d_data = d.mutable_data<float>(dims, place);
float* e_data = e.mutable_data<float>(dims, place);
float* f_data = f.mutable_data<float>(dims, place);
float* g_data = g.mutable_data<float>(dims, place);
float* h_data = h.mutable_data<float>(dims, place);
float* o_data = o.mutable_data<float>(dims, place);
TensorCopySync(cpu_a, place, &a);
TensorCopySync(cpu_b, place, &b);
TensorCopySync(cpu_d, place, &d);
TensorCopySync(cpu_f, place, &f);
code.Compile();
std::vector<void*> args = {&n, &a_data, &b_data, &d_data, &f_data,
&c_data, &e_data, &g_data, &h_data, &o_data};
code.SetNumThreads(1024);
code.SetWorkloadPerThread(1);
code.Launch(n, &args);
TensorCopySync(o, paddle::platform::CPUPlace(), &cpu_o);
for (size_t i = 0; i < n; i++) {
float result =
(1.0 / (1.0 + std::exp(-std::max(
0.0, static_cast<float>(i) * 0.5 + 10.0 - 0.0))));
PADDLE_ENFORCE_EQ(cpu_o.data<float>()[i], result);
}
} }
#endif #endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册