Unverified commit 9091f8cd authored by Yiqun Liu, committed by GitHub

Support generating code for grad_op (#21066)

* Add the definition of operation in fusion_group.

* Use operations in OperationMap to detect fusion_group of elementwise pattern.

* Add namespace fusion_group in code_generator.

* Use operations recorded in OperationMap to generate code.

* Remove implementation codes to .cc file.

* Refine Operation and CodeGenerator to make it easier to generate code for grad_op.
Refine the unittest for better reuse.

* Avoid recording the template's keywords in an array.

* Support generating code for grad_op and add a unittest.
test=develop

* Remove replaced_element_in_order and use numbers instead.
test=develop
Parent 1cd67218
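
A rough sketch of how the new interface introduced by this change is meant to be used (the class and function names come from the diff below; the kernel name "fused_kernel_0", the namespace alias, and the surrounding setup are only illustrative):

namespace fg = paddle::framework::ir::fusion_group;

// t2 = t0 * t1; t3 = relu(t2)
fg::OperationMap::Init();  // Registers forward and grad expressions.
std::vector<fg::OperationExpression> expressions = {
    fg::OperationExpression("elementwise_mul", {0, 1}, {2}),
    fg::OperationExpression("relu", {2}, {3})};
fg::CodeGenerator code_generator;
std::string kernel = code_generator.GenerateCode("fused_kernel_0", expressions);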
cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph)
if(NOT APPLE AND NOT WIN32)
if(WITH_GPU)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor)
...@@ -7,5 +7,5 @@ endif()
cc_library(fusion_group_pass
SRCS fusion_group_pass.cc elementwise_group_detector.cc
DEPS graph_pattern_detector pass code_generator)
cc_test(test_fusion_group_pass SRCS fusion_group_pass_tester.cc DEPS fusion_group_pass graph_viz_pass)
...@@ -20,17 +20,82 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {
CodeGenerator::CodeGenerator() {
// Only support elementwise operations now.
code_templates_.resize(1);
CodeTemplate elementwise_t(elementwise_cuda_template);
code_templates_[0] = elementwise_t;
}
// In order to get the right result of expression, we need to calculate and
// store the expression as suffix Expressions using vector.
std::string CodeGenerator::GenerateCode(
std::string func_name, std::vector<OperationExpression> expressions) {
// Check whether all expressions are elementwise operations.
TemplateVariable template_var;
template_var.Add("func_name", func_name);
template_var.Add("parameters", EmitParameters(expressions, "float"));
template_var.Add("compute_body", EmitComputeBody(expressions));
return predefined_cuda_functions + code_templates_[0].Format(template_var);
}
// we get the parameter list code for the expression information
std::string CodeGenerator::EmitParameters(
std::vector<OperationExpression> expressions, std::string dtype) {
std::set<int> input_ids;
std::set<int> output_ids;
// Remove the repeated ids and get an ordered list.
for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetInputIds()) {
input_ids.insert(id);
}
for (auto id : expressions[i].GetOutputIds()) {
output_ids.insert(id);
}
}
// If an id is in the input and output list at the same time, then remove it
// from the input list.
for (auto iter = input_ids.begin(); iter != input_ids.end();) {
if (output_ids.find(*iter) != output_ids.end()) {
input_ids.erase(iter++);
} else {
iter++;
}
}
std::stringstream ret;
ret << "int N, ";
for (auto iter = input_ids.begin(); iter != input_ids.end(); iter++) {
ret << dtype << "* " << VarName(*iter) << ", ";
}
size_t count_index = 0;
for (auto iter = output_ids.begin(); iter != output_ids.end(); iter++) {
ret << dtype << "* " << VarName(*iter);
if (count_index != output_ids.size() - 1) {
ret << ", ";
}
count_index++;
}
return ret.str();
}
std::string CodeGenerator::EmitComputeBody(
std::vector<OperationExpression> expressions) {
// get the right expression code using suffix expression
std::stringstream ret;
for (size_t i = 0; i < expressions.size(); i++) {
ret << expressions[i].GetExpression();
}
return ret.str();
}
} // namespace fusion_group
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
...@@ -20,19 +21,30 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {
class CodeGenerator {
public:
CodeGenerator();
std::string GenerateCode(std::string func_name,
std::vector<OperationExpression> expressions);
// TODO(wangchao): add a more general interface
// std::string Generate(const std::string name, const SubGraph& subgraph);
private:
// we get the parameter list code for the expression information
std::string EmitParameters(std::vector<OperationExpression> expressions,
std::string dtype);
std::string EmitComputeBody(std::vector<OperationExpression> expressions);
private:
std::vector<CodeTemplate> code_templates_;
};
} // namespace fusion_group
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -17,53 +17,66 @@ limitations under the License. */
#include <sstream>
#include <string>
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {
template <typename T>
static T StringTo(const std::string& str) {
std::istringstream is(str);
T value;
is >> value;
return value;
}
std::string OperationExpression::GetRHS(size_t i) {
auto rhs = OperationMap::Instance().Get(op_).exprs[i];
for (size_t i = 0; i < rhs.size(); i++) {
size_t pos = i;
if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
int length = 0;
while (rhs[pos + 2 + length] != '}') {
length++;
}
std::string index_str = rhs.substr(pos + 2, length);
int index = StringTo<int>(index_str);
PADDLE_ENFORCE_LT(index, input_ids_.size(),
"Only %d inputs are provided, but need %d.",
input_ids_.size(), index + 1);
rhs.replace(pos, length + 3, VarName(input_ids_[index]) + R"([idx])");
}
}
return rhs;
}
std::string OperationExpression::GetLHS(size_t i) {
std::stringstream ret;
ret << VarName(output_ids_[i]) << R"([idx])";
return ret.str();
}
bool OperationExpression::IsSupport() {
return OperationMap::Instance().Has(op_);
}
// We traverse the graph and get the group; all input and output ids are
// unique for the nodes which belong to the group.
std::string OperationExpression::GetExpression() {
std::stringstream ret;
if (IsSupport()) {
for (size_t i = 0; i < output_ids_.size(); ++i) {
ret << GetLHS(i) << " = " << GetRHS(i) << ";";
}
}
return ret.str();
}
} // namespace fusion_group
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -14,52 +14,43 @@ limitations under the License. */
#pragma once
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {
static std::string VarName(int index) { return "var" + std::to_string(index); }
class OperationExpression {
public:
explicit OperationExpression(std::string op, std::vector<int> input_ids,
std::vector<int> output_ids)
: op_(op), input_ids_(input_ids), output_ids_(output_ids) {}
std::vector<int> GetInputIds() { return input_ids_; }
std::vector<int> GetOutputIds() { return output_ids_; }
// Check whether this operation type is supported in OperationMap.
bool IsSupport();
std::string GetExpression();
// TODO(wangchao): make offset more flexible we add stride and basic offset
std::string GetRHS(size_t i = 0);
std::string GetLHS(size_t i = 0);
private:
std::string op_;
std::vector<int> input_ids_;
std::vector<int> output_ids_;
};
class TemplateVariable {
...@@ -92,24 +83,30 @@ class CodeTemplate {
std::string Format(TemplateVariable template_var) {
std::string ret = template_str_;
std::unordered_map<std::string, bool> found;
// Word begins with "$" in template_str will be replaced.
for (size_t i = 0; i < ret.size(); i++) {
auto pos = i;
char c = ret[pos];
if (c == '$') {
for (auto iter : template_var.Get()) {
std::string keyword = iter.first;
if (ret.substr(pos + 1, keyword.size()) == keyword) {
found[keyword] = true;
ret.replace(pos, keyword.size() + 1, iter.second);
}
}
}
}
for (auto iter : template_var.Get()) {
PADDLE_ENFORCE_NE(found.find(iter.first), found.end(),
"Keyword %s in template is not set.", iter.first);
}
return EmitIndents(ret);
}
...@@ -142,103 +139,33 @@ class CodeTemplate {
std::string template_str_;
};
static const char predefined_cuda_functions[] = R"(
__device__ float real_exp(float x) { return ::expf(x); }
__device__ double real_exp(double x) { return ::exp(x); }
__device__ float real_log(float x) { return ::logf(x); }
__device__ double real_log(double x) { return ::log(x); }
__device__ float real_min(float x, float y) { return ::fminf(x, y); }
__device__ double real_min(double x, double y) { return ::fmin(x, y); }
__device__ float real_max(float x, float y) { return ::fmaxf(x, y); }
__device__ double real_max(double x, double y) { return ::fmax(x, y); }
)";
static const char elementwise_cuda_template[] = R"(
extern "C" __global__ void $func_name($parameters) {
for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < N;
idx += gridDim.x * blockDim.x) {
$compute_body
}
}
)";
} // namespace fusion_group
} // namespace ir
} // namespace framework
} // namespace paddle
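
For the first unittest below, the generated kernel string should look roughly like the following (an approximation assembled from elementwise_cuda_template, EmitParameters, and GetExpression; the real output is prefixed with the predefined real_* device functions and its exact whitespace depends on EmitIndents):

extern "C" __global__ void fused_elementwise_0(int N, float* var0, float* var1, float* var3, float* var5, float* var2, float* var4, float* var6, float* var7, float* var8) {
  for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
      idx < N;
      idx += gridDim.x * blockDim.x) {
    var2[idx] = var0[idx] * var1[idx];
    var4[idx] = var2[idx] + var3[idx];
    var6[idx] = var4[idx] - var5[idx];
    var7[idx] = real_max(var6[idx], 0);
    var8[idx] = 1.0 / (1.0 + real_exp(- var7[idx]));
  }
}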
...@@ -17,125 +17,176 @@ limitations under the License. */
#include <cmath>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/device_code.h"
#include "paddle/fluid/platform/init.h"
#ifdef PADDLE_WITH_CUDA
namespace fusion_group = paddle::framework::ir::fusion_group;
template <typename T>
void SetupRandomCPUTensor(paddle::framework::LoDTensor* tensor) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
T* ptr = tensor->data<T>();
PADDLE_ENFORCE_NOT_NULL(
ptr, "Call mutable_data to alloc memory for Tensor first.");
for (int64_t i = 0; i < tensor->numel(); ++i) {
ptr[i] = static_cast<T>(uniform_dist(rng)) - static_cast<T>(0.5);
}
}
void TestMain(std::string func_name,
std::vector<fusion_group::OperationExpression> expressions,
std::vector<paddle::framework::LoDTensor> cpu_tensors, int n,
std::vector<int> input_ids, std::vector<int> output_ids) {
fusion_group::OperationMap::Init();
fusion_group::CodeGenerator code_generator;
std::string code_str = code_generator.GenerateCode(func_name, expressions);
VLOG(3) << code_str;
paddle::framework::InitDevices(false, {0});
paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceCode device_code(place, func_name, code_str);
device_code.Compile();
std::vector<paddle::framework::LoDTensor> gpu_tensors(cpu_tensors.size());
std::vector<float*> gpu_ptrs(gpu_tensors.size());
std::vector<void*> args;
args.push_back(&n);
for (size_t i = 0; i < input_ids.size(); ++i) {
gpu_ptrs[input_ids[i]] = gpu_tensors[input_ids[i]].mutable_data<float>(
cpu_tensors[input_ids[i]].dims(), place);
args.push_back(&gpu_ptrs[input_ids[i]]);
SetupRandomCPUTensor<float>(&cpu_tensors[input_ids[i]]);
TensorCopySync(cpu_tensors[input_ids[i]], place,
&gpu_tensors[input_ids[i]]);
}
for (size_t i = 0; i < output_ids.size(); ++i) {
gpu_ptrs[output_ids[i]] = gpu_tensors[output_ids[i]].mutable_data<float>(
cpu_tensors[output_ids[i]].dims(), place);
args.push_back(&gpu_ptrs[output_ids[i]]);
}
device_code.SetNumThreads(1024);
device_code.SetWorkloadPerThread(1);
device_code.Launch(n, &args);
auto* dev_ctx = reinterpret_cast<paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place));
dev_ctx->Wait();
for (size_t i = 0; i < output_ids.size(); ++i) {
TensorCopySync(gpu_tensors[output_ids[i]], paddle::platform::CPUPlace(),
&cpu_tensors[output_ids[i]]);
}
}
TEST(code_generator, elementwise) {
// t2 = t0 * t1
// t4 = t2 + t3
// t6 = t4 - t5
// t7 = relu(t6)
// t8 = sigmoid(t7)
fusion_group::OperationExpression exp1("elementwise_mul", {0, 1}, {2});
fusion_group::OperationExpression exp2("elementwise_add", {2, 3}, {4});
fusion_group::OperationExpression exp3("elementwise_sub", {4, 5}, {6});
fusion_group::OperationExpression exp4("relu", {6}, {7});
fusion_group::OperationExpression exp5("sigmoid", {7}, {8});
std::vector<fusion_group::OperationExpression> expressions = {
exp1, exp2, exp3, exp4, exp5};
// Prepare CPU tensors
std::vector<paddle::framework::LoDTensor> cpu_tensors(9);
std::vector<int> input_ids = {0, 1, 3, 5};
std::vector<int> output_ids = {2, 4, 6, 7, 8};
auto dims = paddle::framework::make_ddim(
{static_cast<int64_t>(256), static_cast<int64_t>(1024)});
for (size_t i = 0; i < cpu_tensors.size(); ++i) {
cpu_tensors[i].mutable_data<float>(dims, paddle::platform::CPUPlace());
}
int n = cpu_tensors[0].numel();
TestMain("fused_elementwise_0", expressions, cpu_tensors, n, input_ids,
output_ids);
auto cpu_kernel_handler = [&](float* var0, float* var1, float* var3,
float* var5, int i) -> float {
float var2_i = var0[i] * var1[i];
float var4_i = var2_i + var3[i];
float var6_i = var4_i - var5[i];
float var7_i = var6_i > 0.0 ? var6_i : 0.0;
float var8_i = 1.0 / (1.0 + std::exp(-var7_i));
return var8_i;
};
// Check the results
for (int i = 0; i < n; i++) {
float result = cpu_kernel_handler(
cpu_tensors[0].data<float>(), cpu_tensors[1].data<float>(),
cpu_tensors[3].data<float>(), cpu_tensors[5].data<float>(), i);
PADDLE_ENFORCE_LT(fabs(cpu_tensors[8].data<float>()[i] - result), 1.E-05);
}
}
TEST(code_generator, elementwise_grad) {
// The var order: t0, t1, t2, t3, t0', t1', t2', t3'
// t2 = t0 * t1
// t3 = relu(t2)
// t2' = relu_grad(t2, t3, t3')
// t0', t1' = elementwise_mul_grad(t0, t1, t2, t2')
fusion_group::OperationExpression exp1("relu_grad", {2, 3, 7}, {6});
fusion_group::OperationExpression exp2("elementwise_mul_grad", {0, 1, 2, 6},
{4, 5});
std::vector<fusion_group::OperationExpression> expressions = {exp1, exp2};
// Prepare CPU tensors
std::vector<paddle::framework::LoDTensor> cpu_tensors(8);
std::vector<int> input_ids = {0, 1, 2, 3, 7};
std::vector<int> output_ids = {4, 5, 6};
auto dims = paddle::framework::make_ddim(
{static_cast<int64_t>(256), static_cast<int64_t>(1024)});
for (size_t i = 0; i < cpu_tensors.size(); ++i) {
cpu_tensors[i].mutable_data<float>(dims, paddle::platform::CPUPlace());
} }
int n = cpu_tensors[0].numel();
TestMain("fused_elementwise_grad_0", expressions, cpu_tensors, n, input_ids,
output_ids);
auto cpu_kernel_handler = [&](float* var0, float* var1, float* var2,
float* var3, float* var7,
int i) -> std::vector<float> {
float var6_i = var2[i] > 0 ? var7[i] : 0;
float var4_i = var6_i * var1[i];
float var5_i = var6_i * var0[i];
return std::vector<float>{var4_i, var5_i, var6_i};
};
// Check the results
for (int i = 0; i < n; i++) {
std::vector<float> results = cpu_kernel_handler(
cpu_tensors[0].data<float>(), cpu_tensors[1].data<float>(),
cpu_tensors[2].data<float>(), cpu_tensors[3].data<float>(),
cpu_tensors[7].data<float>(), i);
PADDLE_ENFORCE_LT(fabs(cpu_tensors[4].data<float>()[i] - results[0]),
1.E-05);
PADDLE_ENFORCE_LT(fabs(cpu_tensors[5].data<float>()[i] - results[1]),
1.E-05);
PADDLE_ENFORCE_LT(fabs(cpu_tensors[6].data<float>()[i] - results[2]),
1.E-05);
}
}
#endif
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/elementwise_group_detector.h"
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
...@@ -20,12 +21,22 @@ namespace framework {
namespace ir {
namespace fusion_group {
static std::unordered_set<std::string> binary_op_types;
static std::unordered_set<std::string> unary_op_types;
static std::unordered_set<std::string>& GetBinaryOpTypes() {
if (binary_op_types.empty()) {
binary_op_types = OperationMap::Instance().Find(0, 2);
}
return binary_op_types;
}
static std::unordered_set<std::string>& GetUnaryOpTypes() {
if (unary_op_types.empty()) {
unary_op_types = OperationMap::Instance().Find(0, 1);
}
return unary_op_types;
}
static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
Node* n) {
...@@ -39,7 +50,7 @@ static bool IsSpecifiedOp(const std::unordered_set<std::string>& op_types,
}
static bool IsBinaryOp(Node* n) {
if (IsSpecifiedOp(GetBinaryOpTypes(), n) && n->inputs.size() == 2U) {
auto* x = n->inputs[0];
auto* y = n->inputs[1];
...@@ -64,7 +75,7 @@ static bool IsBinaryOp(Node* n) {
return false;
}
static bool IsUnaryOp(Node* n) { return IsSpecifiedOp(GetUnaryOpTypes(), n); }
bool ElementwiseGroupDetector::IsElementwiseOp(Node* n) {
return IsBinaryOp(n) || IsUnaryOp(n);
......
...@@ -25,7 +25,7 @@ namespace framework {
namespace ir {
namespace fusion_group {
class ElementwiseGroupDetector {
public:
int operator()(Node* n);
......
...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/fusion_group_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
...@@ -22,6 +23,8 @@ namespace framework {
namespace ir {
TEST(FusionGroupPass, elementwise_list) {
fusion_group::OperationMap::Init();
// inputs operator output
// --------------------------------------------------------
// (x, y) mul -> tmp_0
...@@ -69,6 +72,8 @@ TEST(FusionGroupPass, elementwise_list) {
}
TEST(FusionGroupPass, elementwise_tree) {
fusion_group::OperationMap::Init();
// inputs operator output
// --------------------------------------------------------
// (x0, y0) mul -> tmp_0
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {
OperationMap* OperationMap::map = nullptr;
OperationMap::OperationMap() {
InsertUnaryElementwiseOperations();
InsertBinaryElementwiseOperations();
}
std::unordered_set<std::string> OperationMap::Find(int type, int num_operands) {
std::unordered_set<std::string> res;
for (auto& t : operations_) {
if ((t.second.type == type) &&
(num_operands < 0 || t.second.num_operands == num_operands)) {
res.insert(t.first);
}
}
return res;
}
void OperationMap::Insert(int type, int num_operands, std::string op_type,
std::string expr,
std::vector<std::string> grad_exprs) {
Operation op(type, num_operands, op_type, {expr});
PADDLE_ENFORCE_EQ(op.IsValid(), true, "Operation %s is invalid.", op_type);
operations_[op_type] = op;
if (grad_exprs.size() > 0U) {
std::string grad_op_type = op_type + "_grad";
Operation grad_op(type, num_operands, grad_op_type, grad_exprs);
PADDLE_ENFORCE_EQ(grad_op.IsValid(), true, "Operation %s is invalid.",
grad_op_type);
operations_[grad_op_type] = grad_op;
}
}
void OperationMap::InsertUnaryElementwiseOperations() {
int type = 0;
int num_oprands = 1;
// For unary elementwise operations:
// ${0} - x
// ${1} - out
// ${2} - dout
// relu:
// out = f(x) = x > 0 ? x : 0
// dx = dout * (out > 0 ? 1 : 0) = dout * (x > 0 ? 1 : 0)
Insert(type, num_oprands, "relu", "real_max(${0}, 0)",
{"${0} > 0 ? ${2} : 0"});
// sigmoid:
// out = f(x) = 1.0 / (1.0 + exp(-x))
// dx = dout * out * (1 - out)
Insert(type, num_oprands, "sigmoid", "1.0 / (1.0 + real_exp(- ${0}))",
{"${2} * ${1} * (1.0 - ${1})"});
// tanh:
// out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0;
// dx = dout * (1 - out * out)
Insert(type, num_oprands, "tanh", "2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0",
{"${2} * (1.0 - ${1} * ${1})"});
}
void OperationMap::InsertBinaryElementwiseOperations() {
int type = 0;
int num_oprands = 2;
// For binary elementwise operations:
// ${0} - x
// ${1} - y
// ${2} - out
// ${3} - dout
// elementwise_add:
// out = x + y
// dx = dout * 1
// dy = dout * 1
Insert(type, num_oprands, "elementwise_add", "${0} + ${1}", {"${3}", "${3}"});
// elementwise_sub:
// out = x - y
// dx = dout * 1
// dy = dout * (-1)
Insert(type, num_oprands, "elementwise_sub", "${0} - ${1}",
{"${3}", "- ${3}"});
// elementwise_mul:
// out = x * y
// dx = dout * y
// dy = dout * x
Insert(type, num_oprands, "elementwise_mul", "${0} * ${1}",
{"${3} * ${1}", "${3} * ${0}"});
Insert(type, num_oprands, "elementwise_div", "${0} / ${1}", {});
Insert(type, num_oprands, "elementwise_min", "real_min(${0}, ${1})", {});
Insert(type, num_oprands, "elementwise_max", "real_max(${0}, ${1})", {});
}
} // namespace fusion_group
} // namespace ir
} // namespace framework
} // namespace paddle
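
To make the grad support concrete: relu_grad is registered above with the single expression "${0} > 0 ? ${2} : 0", where ${i} refers to the i-th entry of an expression's input_ids. With OperationExpression("relu_grad", {2, 3, 7}, {6}), as in the unittest, GetRHS substitutes each ${i} with VarName(input_ids_[i]) plus "[idx]", so GetExpression() should expand to (illustrative expansion):

var6[idx] = var2[idx] > 0 ? var7[idx] : 0;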
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
namespace fusion_group {
struct Operation {
Operation() {}
Operation(int t, int n, std::string o, std::vector<std::string> e)
: type(t), num_operands(n), op_type(o), exprs(e) {}
bool IsGradOp() {
std::string suffix = "_grad";
return op_type.rfind(suffix) == (op_type.length() - suffix.length());
}
bool IsValid() {
if (!IsGradOp() && exprs.size() != 1U) {
return false;
}
if (IsGradOp() && exprs.size() != static_cast<size_t>(num_operands)) {
return false;
}
return true;
}
int type;
int num_operands;
std::string op_type;
std::vector<std::string> exprs;
};
class OperationMap {
public:
OperationMap();
static OperationMap& Instance() {
PADDLE_ENFORCE_NOT_NULL(map, "Need to initialize OperationMap first!");
return *map;
}
static OperationMap& Init() {
if (map == nullptr) {
map = new OperationMap();
}
return *map;
}
std::unordered_set<std::string> Find(int type, int num_operands = -1);
bool Has(std::string op_type) {
return operations_.find(op_type) != operations_.end();
}
Operation& Get(std::string op_type) {
auto iter = operations_.find(op_type);
PADDLE_ENFORCE_NE(iter, operations_.end(),
"Operation %s is not supported yet.", op_type);
return iter->second;
}
private:
void Insert(int type, int num_operands, std::string op_type, std::string expr,
std::vector<std::string> grad_exprs);
void InsertUnaryElementwiseOperations();
void InsertBinaryElementwiseOperations();
private:
static OperationMap* map;
std::unordered_map<std::string, Operation> operations_;
DISABLE_COPY_AND_ASSIGN(OperationMap);
};
} // namespace fusion_group
} // namespace ir
} // namespace framework
} // namespace paddle