未验证 提交 0a963ee9 编写于 作者: C CtfGo 提交者: GitHub

add cinn_launch_op for using CINN to optimize graph (#36600)

增加CinnLaunchOp,负责执行Cinn子图编译的结果,要点如下:
1. 在子图划分的BuildCinnPass中,每个子图在原图中会被替换为该CinnLaunchOp,由它来调用Cinn进行子图编译、执行的功能。
2. CinnLaunchOp的输入/输出即为子图的输入和输出,另外增加`compilation_key`属性,它可由该属性key从全局Cache中获取子图对象、编译结果,该属性由BuildCinnPass在创建Op时进行设置
3. CinnLaunchOp功能实现的流程为:
        - 从全局Cache中获取子图对象
        - 从全局Cache中获取子图编译结果,未命中cache时进行即时编译
        - 根据编译结果的变量信息(数据类型、shape)初始化运行时数据,分配内存/显存
        - 将运行时数据打包为参数,调用cinn的可执行对象runtime program进行计算
        - 子图运行结果通过参数指针同步到paddle侧的tensor
上级 8937205b
......@@ -112,12 +112,15 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
<< cinn_graph->Visualize();
ApplyPass(cinn_graph.get(), "OpFusion");
auto scope = BuildScope(target, cinn_graph);
GraphCompiler graph_compiler(target, scope, cinn_graph);
auto graph_compiler =
std::make_unique<GraphCompiler>(target, scope, cinn_graph);
GraphCompiler::CompileOptions options;
options.with_instantiate_variables = false;
auto compiled_res = graph_compiler.Build(options);
auto compiled_res = graph_compiler->Build(options);
auto compiled_obj = std::make_unique<CinnCompiledObject>();
*compiled_obj = {std::move(compiled_res.runtime_program), scope,
*compiled_obj = {std::move(graph_compiler),
std::move(compiled_res.runtime_program), scope,
symbol.var_model_to_program_map()};
return compiled_obj;
}
......
......@@ -33,6 +33,7 @@ namespace framework {
namespace paddle2cinn {
struct CinnCompiledObject {
std::unique_ptr<::cinn::hlir::framework::GraphCompiler> compiler;
std::unique_ptr<::cinn::hlir::framework::Program> runtime_program;
std::shared_ptr<::cinn::hlir::framework::Scope> scope;
std::unordered_map<std::string, std::string> paddle2cinn_varmap;
......
......@@ -79,8 +79,8 @@ if(WITH_UNITY_BUILD)
include(unity_build_rule.cmake)
endif()
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op cinn_launch_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
op_library(save_combine_op DEPS string_array)
......@@ -166,6 +166,15 @@ if (WITH_ASCEND_CL)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} npu_op_runner)
endif()
if (WITH_CINN)
cc_library(cinn_launch_op_helper SRCS cinn_launch_op_helper.cc DEPS operator cinn)
cc_test(cinn_launch_op_helper_test SRCS cinn_launch_op_helper_test.cc DEPS cinn_launch_op_helper)
op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS cinn_compiler cinn_launch_op_helper cinn ${OP_HEADER_DEPS})
if (WITH_GPU)
nv_test(cinn_launch_op_test SRCS cinn_launch_op_test.cc DEPS cinn_compiler cinn_launch_op elementwise_add_op)
endif()
endif()
# FIXME(typhoonzero): operator deps may not needed.
# op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
# op_library(array_to_lod_tensor_op DEPS lod_rank_table_op)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/cinn_launch_op.h"
namespace paddle {
namespace operators {
class CinnLaunchOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnLaunchOp");
OP_INOUT_CHECK(ctx->HasOutput(kOutputs), "Output", kOutputs,
"CinnLaunchOp");
}
protected:
/* [Why use single type kernel]:
*
* This op is similar to a control flow op, it doses not need
* a op kernel, but in order to make it execute under dynamic
* graph mode, implement it with op kernel.
*
* So whether the kernel data type is int, float or other type,
* which has no effect on its execution logic, so directly
* specified a data type here.
*
* Of course, the data type here is also not important.
*/
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(framework::proto::VarType::FP32,
ctx.GetPlace());
}
};
class CinnLaunchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(kX,
"(vector<LoDTensor>)"
"which are the input of graph inside the CinnLaunchOp.")
.AsDuplicable();
AddOutput(kOutputs,
"(vector<LoDTensor>)"
"which are the output of graph inside the CinnLaunchOp.")
.AsDuplicable();
AddAttr<std::string>(
kCompilationKey,
"(string)"
"a hash key used to get the graph object or its computation result.");
AddComment(R"DOC(
CinnLaunch Operator.
This operator is used to launch CINN(https://github.com/PaddlePaddle/CINN/blob/develop/README.md)
to compile a graph and execute the compiled object.
Both input and output of this operator are a set of variables
which are input and output of the graph respectively that will be
compiled and executed in this operator.
In addition, there is an attribute named 'compilation_key' should be
set necessarily to get corresponding ir::Graph object of the graph
or its computation result.
It accomplishs the computation of graph following several steps:
1. Fetch ir::Graph object from CinnCompiler using kCompilationKey
2. Compile the graph to a compiled object, and insert it to the
global cache so that we can directly query it from this cache next time
when shape of input variables are not changed at all.
3. Create and instantiate all variables used to execute compiled runtime program
if necessary according to the info(type,shape) included in the return scope.
4. Pack each tensor buffer of all above variables as execution arguments.
5. Launch execution of the runtime program with above arguments, then
the result would be output by writing value on underlying buffer address.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
cinn_launch, ops::CinnLaunchOp, ops::CinnLaunchOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
/* see [Why use single type kernel] */
REGISTER_OP_CPU_KERNEL(
cinn_launch,
ops::CinnLaunchOpKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cinn_launch_op.h"
/* see [Why use single type kernel] */
REGISTER_OP_CUDA_KERNEL(cinn_launch,
paddle::operators::CinnLaunchOpKernel<
paddle::platform::CUDADeviceContext, float>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/scope.h"
#include "cinn/runtime/cinn_runtime.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn_launch_op_helper.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace operators {
static constexpr char kX[] = "X";
static constexpr char kOutputs[] = "Out";
static constexpr char kCompilationKey[] = "compilation_key";
using LoDTensor = framework::LoDTensor;
using Name2ConstTensor = std::map<std::string, const LoDTensor*>;
using CinnTensor = cinn::hlir::framework::Tensor;
using Name2CinnTensor = std::unordered_map<std::string, CinnTensor>;
using framework::paddle2cinn::CinnCompiler;
template <typename DeviceContext, typename T>
class CinnLaunchOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Step 1. Find graph object and prepare input
PADDLE_ENFORCE_EQ(ctx.HasAttr(kCompilationKey), true,
platform::errors::NotFound(
"No Attribute(%s) found for CinnLaunchOp operator.",
kCompilationKey));
const auto& compilation_key =
ctx.template Attr<std::string>(kCompilationKey);
VLOG(2) << "CinnLaunchOp compilation_key:" << compilation_key;
const auto& graph = CinnCompiler::GetInstance()->FindGraph(compilation_key);
auto input_variable_names = ctx.InputNames(kX);
Name2ConstTensor input_tensors =
details::GetConstTensors(ctx.scope(), input_variable_names);
// Step 2. Get compilation result of the graph
auto target = details::PlaceToCinnTarget(ctx.GetPlace());
const auto& cinn_compiled_object =
CinnCompiler::GetInstance()->Compile(graph, input_tensors, target);
VLOG(2) << "CinnLaunchOp compile graph done on " << ctx.GetPlace();
const auto& cinn_runtime_program = cinn_compiled_object.runtime_program;
const auto& compiled_scope = *(cinn_compiled_object.scope.get());
const auto& paddle2cinn_varmap = cinn_compiled_object.paddle2cinn_varmap;
// Step 3. Initialize all variables of the compilation runtime program
// in paddle, and pack them into execution arguments
VLOG(2) << "CinnLaunchOp prepare execution arguments";
std::map<std::string, cinn_pod_value_t> name2argument;
std::vector<std::unique_ptr<cinn_buffer_t>> hold_buffers;
// prepare input variables
Name2CinnTensor input_compiled_tensors = details::GetCompiledTensors(
input_variable_names, compiled_scope, paddle2cinn_varmap);
details::CheckTensorEquivalent(input_tensors, input_compiled_tensors);
details::AppendExecutionArguments(ctx.scope(), input_variable_names,
paddle2cinn_varmap, &name2argument,
&hold_buffers);
// prepare output variables
auto output_variable_names = ctx.OutputNames(kOutputs);
Name2CinnTensor output_compiled_tensors = details::GetCompiledTensors(
output_variable_names, compiled_scope, paddle2cinn_varmap);
details::InitializeOutputVar(ctx.scope(), ctx.GetPlace(),
output_compiled_tensors);
Name2ConstTensor output_tensors =
details::GetConstTensors(ctx.scope(), output_variable_names);
details::CheckTensorEquivalent(output_tensors, output_compiled_tensors);
details::AppendExecutionArguments(ctx.scope(), output_variable_names,
paddle2cinn_varmap, &name2argument,
&hold_buffers);
// prepare temporary variables
auto temp_variable_names =
details::SeperateTempVar(compiled_scope, paddle2cinn_varmap,
input_variable_names, output_variable_names);
auto temp_scope = ctx.scope().NewTmpScope();
if (!temp_variable_names.empty()) {
details::InitializeTempVar(temp_variable_names, compiled_scope,
ctx.GetPlace(), temp_scope.get());
details::AppendExecutionArguments(*temp_scope, temp_variable_names,
paddle2cinn_varmap, &name2argument,
&hold_buffers);
}
// Step 4. Launch CINN to execute the compilation runtime program
cinn_runtime_program->Execute(&name2argument);
VLOG(2) << "CinnLaunchOp launch runtime_program execution done.";
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/cinn_launch_op_helper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace details {
using LoDTensor = framework::LoDTensor;
using Scope = framework::Scope;
using Name2ConstTensor = std::map<std::string, const LoDTensor*>;
using CinnTensor = cinn::hlir::framework::Tensor;
using CinnScope = cinn::hlir::framework::Scope;
using Name2CinnTensor = std::unordered_map<std::string, CinnTensor>;
const cinn::common::Target& PlaceToCinnTarget(const platform::Place& place) {
if (platform::is_cpu_place(place)) {
return cinn::common::DefaultHostTarget();
} else if (platform::is_gpu_place(place)) {
return cinn::common::DefaultNVGPUTarget();
}
PADDLE_THROW(platform::errors::InvalidArgument(
"CINN is not supported on current place:%s", place));
return cinn::common::UnkTarget();
}
Name2ConstTensor GetConstTensors(
const Scope& scope, const std::vector<std::string>& variable_names) {
Name2ConstTensor name2tensor;
for (const auto& var_name : variable_names) {
auto* var_ptr = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var_ptr, platform::errors::NotFound("Variable(%s) not found in Scope.",
var_name));
PADDLE_ENFORCE_EQ(var_ptr->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"Variable(%s) is not LoDTensor that is "
"the only supported by compiler now.",
var_name));
name2tensor.emplace(var_name, &var_ptr->Get<framework::LoDTensor>());
}
return name2tensor;
}
Name2CinnTensor GetCompiledTensors(
const std::vector<std::string>& paddle_var_names,
const CinnScope& compiled_scope,
const std::unordered_map<std::string, std::string>& paddle2cinn_varmap) {
Name2CinnTensor name2tensor;
for (const auto& pd_name : paddle_var_names) {
PADDLE_ENFORCE_GT(paddle2cinn_varmap.count(pd_name), 0,
platform::errors::NotFound(
"the corresponding compiled one of variable(%s) "
"not found in compilation result.",
pd_name));
const auto& cinn_name = paddle2cinn_varmap.at(pd_name);
PADDLE_ENFORCE_NOT_NULL(
compiled_scope.FindVar(cinn_name),
platform::errors::NotFound("Variable(%s) not found in compiled scope.",
pd_name));
name2tensor.emplace(pd_name, compiled_scope.GetTensor(cinn_name));
}
return name2tensor;
}
void CheckTensorEquivalent(const Name2ConstTensor& paddle_tensors,
const Name2CinnTensor& compiled_tensors) {
for (const auto& name2tensor : paddle_tensors) {
const auto& pd_name = name2tensor.first;
const auto* paddle_tensor = name2tensor.second;
PADDLE_ENFORCE_EQ(
paddle_tensor->IsInitialized(), true,
platform::errors::InvalidArgument(
"The tensor in variable(%s) is not initialized.", pd_name));
PADDLE_ENFORCE_GT(compiled_tensors.count(pd_name), 0,
platform::errors::NotFound(
"the corresponding compiled tensor of variable(%s) "
"not found in compilation result.",
pd_name));
const auto& cinn_tensor = compiled_tensors.at(pd_name);
auto compiled_dim = framework::make_ddim(cinn_tensor->shape().data());
PADDLE_ENFORCE_EQ(paddle_tensor->dims(), compiled_dim,
platform::errors::InvalidArgument(
"The tensor dimension in variable(%s) "
"is not equivalent, paddle is [%s] "
"but compiled result is [%s].",
pd_name, paddle_tensor->dims(), compiled_dim));
// TODO(CtfGo): check the underlying data type is equivalent
}
}
void InitializeOutputVar(const Scope& scope, const platform::Place& place,
const Name2CinnTensor& compiled_tensors) {
for (const auto& name2tensor : compiled_tensors) {
const auto& pd_name = name2tensor.first;
const auto& cinn_tensor = name2tensor.second;
auto* var_ptr = scope.FindVar(pd_name);
PADDLE_ENFORCE_NOT_NULL(
var_ptr, platform::errors::NotFound("Variable(%s) not found in scope.",
pd_name));
auto* paddle_tensor = var_ptr->GetMutable<LoDTensor>();
if (!paddle_tensor->IsInitialized()) {
// TODO(CtfGo): support mutable corresponding c++ type with the
// compilation type
paddle_tensor->mutable_data<float>(
framework::make_ddim(cinn_tensor->shape().data()), place);
VLOG(2) << "Variable(" << pd_name
<< ") is initialized using compilation result, type:"
<< paddle_tensor->type() << ", dims:" << paddle_tensor->dims();
}
}
}
std::vector<std::string> SeperateTempVar(
const CinnScope& compiled_scope,
const std::unordered_map<std::string, std::string>& paddle2cinn_varmap,
const std::vector<std::string>& input_var_names,
const std::vector<std::string>& output_var_names) {
std::unordered_set<std::string> all_paddle_names, all_cinn_names;
for_each(paddle2cinn_varmap.begin(), paddle2cinn_varmap.end(),
[&all_paddle_names](const auto& name_pd2cinn) {
all_paddle_names.insert(name_pd2cinn.first);
});
auto cinn_names_view = compiled_scope.var_names();
for_each(cinn_names_view.begin(), cinn_names_view.end(),
[&all_cinn_names](const auto& str_view) {
all_cinn_names.emplace(str_view.data(), str_view.size());
});
auto exclude_fn = [&](const auto& pd_name) {
PADDLE_ENFORCE_EQ(all_paddle_names.erase(pd_name), 1,
platform::errors::NotFound(
"The corresponding compiled one of variable(%s) "
"not found in compilation result.",
pd_name));
PADDLE_ENFORCE_EQ(all_cinn_names.erase(paddle2cinn_varmap.at(pd_name)), 1,
platform::errors::NotFound(
"Variable(%s) not found in compiled scope", pd_name));
};
for_each(input_var_names.begin(), input_var_names.end(), exclude_fn);
for_each(output_var_names.begin(), output_var_names.end(), exclude_fn);
if (all_cinn_names.empty()) {
VLOG(2) << "No temporary variable is needed during "
"execution in cinn runtime program";
return {};
}
return {all_cinn_names.begin(), all_cinn_names.end()};
}
void InitializeTempVar(const std::vector<std::string>& variable_names,
const CinnScope& compiled_scope,
const platform::Place& place, Scope* temp_scope) {
for (const auto& var_name : variable_names) {
PADDLE_ENFORCE_NOT_NULL(
compiled_scope.FindVar(var_name),
platform::errors::NotFound(
"Temporary variable(%s) not found in compiled scope", var_name));
const auto& cinn_tensor = compiled_scope.GetTensor(var_name);
// use the same variable name defined by CINN
auto* var_ptr = temp_scope->Var(var_name);
auto* paddle_tensor = var_ptr->GetMutable<LoDTensor>();
auto compiled_ddim = framework::make_ddim(cinn_tensor->shape().data());
// TODO(CtfGo): support mutable corresponding c++ type
paddle_tensor->mutable_data<float>(compiled_ddim, place);
VLOG(2) << "Add temporary variable(" << var_name << "), dimension is ["
<< compiled_ddim << "]";
}
}
void SharePaddleTensorWithCinnBuffer(LoDTensor* paddle_tensor,
cinn_buffer_t* cinn_buffer) {
std::vector<cinn_dimension_t> cinn_dims(paddle_tensor->dims().size());
for (auto i = 0; i < cinn_dims.size(); ++i) {
cinn_dims[i] = static_cast<cinn_dimension_t>(paddle_tensor->dims().at(i));
}
cinn_buffer->resize(cinn_dims.data(), cinn_dims.size());
cinn_buffer->memory =
reinterpret_cast<uint8_t*>(paddle_tensor->data<float>());
}
void AppendExecutionArguments(
const Scope& scope, const std::vector<std::string>& variable_names,
const std::unordered_map<std::string, std::string>& paddle2cinn_varmap,
std::map<std::string, cinn_pod_value_t>* name2argument,
std::vector<std::unique_ptr<cinn_buffer_t>>* hold_buffers) {
for (const auto& pd_name : variable_names) {
auto* var_ptr = scope.FindVar(pd_name);
PADDLE_ENFORCE_NOT_NULL(
var_ptr, platform::errors::NotFound("Variable(%s) not found in Scope.",
pd_name));
auto* paddle_tensor = var_ptr->GetMutable<LoDTensor>();
// if not found a paddle variable in the map,
// which means it is a temporary variable extra added,
// so the paddle name is same with cinn
const auto& cinn_name = paddle2cinn_varmap.count(pd_name)
? paddle2cinn_varmap.at(pd_name)
: pd_name;
std::unique_ptr<cinn_buffer_t> buffer_ptr(new cinn_buffer_t());
SharePaddleTensorWithCinnBuffer(paddle_tensor, buffer_ptr.get());
name2argument->emplace(cinn_name, buffer_ptr.get());
hold_buffers->emplace_back(std::move(buffer_ptr));
}
}
} // namespace details
} // namespace operators
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include "cinn/common/target.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/scope.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace operators {
namespace details {
const cinn::common::Target& PlaceToCinnTarget(const platform::Place& place);
// Get the underlying tensor of a variable,
// result: paddle name --> const LoDTensor*
std::map<std::string, const framework::LoDTensor*> GetConstTensors(
const framework::Scope& scope,
const std::vector<std::string>& variable_names);
// Get the compiled tensor of a paddle variable,
// result: paddle name --> CinnTensor
std::unordered_map<std::string, cinn::hlir::framework::Tensor>
GetCompiledTensors(
const std::vector<std::string>& paddle_var_names,
const cinn::hlir::framework::Scope& compiled_scope,
const std::unordered_map<std::string, std::string>& paddle2cinn_varmap);
// Check a original tensor of Paddle is equivalent
// to the complied tensor from CINN
void CheckTensorEquivalent(
/*paddle name -> const LoDTensor**/
const std::map<std::string, const framework::LoDTensor*>& paddle_tensors,
/*paddle name -> CinnTensor*/
const std::unordered_map<std::string, cinn::hlir::framework::Tensor>&
compiled_tensors);
// Initialize output variables with the compilation result from CINN
void InitializeOutputVar(
const framework::Scope& scope, const platform::Place& place,
/*paddle name -> CinnTensor*/
const std::unordered_map<std::string, cinn::hlir::framework::Tensor>&
compiled_tensors);
// Extract extral temporary variables by
// excluding input/output variables from compiled scope
std::vector<std::string> SeperateTempVar(
const cinn::hlir::framework::Scope& compiled_scope,
const std::unordered_map<std::string, std::string>& paddle2cinn_varmap,
const std::vector<std::string>& input_var_names,
const std::vector<std::string>& output_var_names);
// Initialize temporary variables in a temp scope,
// using the definition in compiled_scope
void InitializeTempVar(const std::vector<std::string>& variable_names,
const cinn::hlir::framework::Scope& compiled_scope,
const platform::Place& place,
framework::Scope* temp_scope);
// Share paddle tensor to a cinn one through cinn_buffer_t object
void SharePaddleTensorWithCinnBuffer(framework::LoDTensor* paddle_tensor,
cinn_buffer_t* cinn_buffer);
// Pack tensors of all variables as execution arguments,
// which will be passed into compilation runtime program to execute
void AppendExecutionArguments(
const framework::Scope& scope,
const std::vector<std::string>& variable_names,
const std::unordered_map<std::string, std::string>& paddle2cinn_varmap,
std::map<std::string, cinn_pod_value_t>* name2argument,
std::vector<std::unique_ptr<cinn_buffer_t>>* hold_buffers);
} // namespace details
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cinn_launch_op_helper.h"
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace details {
using LoDTensor = framework::LoDTensor;
using Scope = framework::Scope;
using CinnShape = cinn::hlir::framework::Shape;
using CinnTensor = cinn::hlir::framework::Tensor;
using CinnScope = cinn::hlir::framework::Scope;
TEST(CinnLaunchOpHelperTest, TestPlaceToCinnTarget) {
ASSERT_EQ(PlaceToCinnTarget(platform::CPUPlace()),
cinn::common::DefaultHostTarget());
ASSERT_EQ(PlaceToCinnTarget(platform::CUDAPlace(0)),
cinn::common::DefaultNVGPUTarget());
}
TEST(CinnLaunchOpHelperTest, TestGetConstTensors) {
// build test data
Scope scope;
auto* var1 = scope.Var("lodtensor_var_1");
var1->GetMutable<LoDTensor>();
auto* var2 = scope.Var("lodtensor_var_2");
var2->GetMutable<LoDTensor>();
auto* var3 = scope.Var("selectedrows_var_1");
var3->GetMutable<framework::SelectedRows>();
// get expected result with legal input
auto name2tensor =
GetConstTensors(scope, {"lodtensor_var_1", "lodtensor_var_2"});
ASSERT_EQ(name2tensor.size(), 2);
EXPECT_EQ(name2tensor.at("lodtensor_var_1"), &var1->Get<LoDTensor>());
EXPECT_EQ(name2tensor.at("lodtensor_var_2"), &var2->Get<LoDTensor>());
}
TEST(CinnLaunchOpHelperTest, TestGetCompiledTensors) {
// build test data
std::unordered_map<std::string, std::string> paddle2cinn_varmap(
{{"pd_var1", "cinn_var1"},
{"pd_var2", "cinn_var2"},
{"pd_var3", "cinn_var3"}});
CinnScope compiled_scope;
compiled_scope.Var<CinnTensor>("cinn_var1");
compiled_scope.Var<CinnTensor>("cinn_var2");
// get expected result with legal input
auto name2tensor = GetCompiledTensors({"pd_var1", "pd_var2"}, compiled_scope,
paddle2cinn_varmap);
ASSERT_EQ(name2tensor.size(), 2);
EXPECT_EQ(name2tensor.at("pd_var1").get(),
compiled_scope.GetTensor("cinn_var1").get());
EXPECT_EQ(name2tensor.at("pd_var2").get(),
compiled_scope.GetTensor("cinn_var2").get());
}
TEST(CinnLaunchOpHelperTest, TestCheckTensorEquivalent) {
// build test data
platform::CPUPlace place;
Scope scope;
CinnScope compiled_scope;
auto* tensor1 = scope.Var("var1")->GetMutable<LoDTensor>();
auto dims1 = std::vector<int>({2, 3});
tensor1->mutable_data<float>(framework::make_ddim(dims1), place);
auto* tensor2 = scope.Var("var2")->GetMutable<LoDTensor>();
auto dims2 = std::vector<int>({5, 6, 7});
tensor2->mutable_data<float>(framework::make_ddim(dims2), place);
auto* tensor3 = scope.Var("var3")->GetMutable<LoDTensor>();
tensor3->mutable_data<float>(framework::make_ddim({10, 20}), place);
auto* tensor4 = scope.Var("var4")->GetMutable<LoDTensor>();
tensor4->mutable_data<float>(framework::make_ddim({2, 4, 6}), place);
compiled_scope.Var<CinnTensor>("var1");
compiled_scope.Var<CinnTensor>("var2");
compiled_scope.Var<CinnTensor>("var3");
auto compiled_tensor1 = compiled_scope.GetTensor("var1");
compiled_tensor1->Resize(CinnShape(dims1));
auto compiled_tensor2 = compiled_scope.GetTensor("var2");
compiled_tensor2->Resize(CinnShape(dims2));
auto compiled_tensor3 = compiled_scope.GetTensor("var3");
compiled_tensor3->Resize(CinnShape({10}));
// expected equality
CheckTensorEquivalent(
{{"var1", tensor1}, {"var2", tensor2}},
{{"var1", compiled_tensor1}, {"var2", compiled_tensor2}});
}
TEST(CinnLaunchOpHelperTest, TestInitializeOutputVar) {
// build test data
platform::CPUPlace place;
Scope scope;
scope.Var("var1");
scope.Var("var2");
CinnScope compiled_scope;
compiled_scope.Var<CinnTensor>("var1");
compiled_scope.Var<CinnTensor>("var2");
compiled_scope.Var<CinnTensor>("var3");
auto compiled_tensor1 = compiled_scope.GetTensor("var1");
compiled_tensor1->Resize(CinnShape({2, 3}));
auto compiled_tensor2 = compiled_scope.GetTensor("var2");
compiled_tensor2->Resize(CinnShape({5, 6, 7}));
auto compiled_tensor3 = compiled_scope.GetTensor("var3");
compiled_tensor3->Resize(CinnShape({10}));
// expected result
InitializeOutputVar(scope, place,
{{"var1", compiled_tensor1}, {"var2", compiled_tensor2}});
auto* var1 = scope.FindVar("var1");
ASSERT_TRUE(var1->IsType<LoDTensor>());
EXPECT_TRUE(var1->Get<LoDTensor>().IsInitialized());
EXPECT_EQ(var1->Get<LoDTensor>().dims(), framework::make_ddim({2, 3}));
auto* var2 = scope.FindVar("var2");
ASSERT_TRUE(var2->IsType<LoDTensor>());
EXPECT_TRUE(var2->Get<LoDTensor>().IsInitialized());
EXPECT_EQ(var2->Get<LoDTensor>().dims(), framework::make_ddim({5, 6, 7}));
}
TEST(CinnLaunchOpHelperTest, TestSeperateTempVar) {
CinnScope compiled_scope;
compiled_scope.Var<CinnTensor>("cinn_temp_var1");
compiled_scope.Var<CinnTensor>("cinn_input_var1");
compiled_scope.Var<CinnTensor>("cinn_input_var2");
compiled_scope.Var<CinnTensor>("cinn_temp_var2");
compiled_scope.Var<CinnTensor>("cinn_output_var1");
auto variable_names =
SeperateTempVar(compiled_scope, {{"input_var1", "cinn_input_var1"},
{"input_var2", "cinn_input_var2"},
{"output_var1", "cinn_output_var1"}},
{"input_var1", "input_var2"}, {"output_var1"});
ASSERT_EQ(variable_names.size(), 2);
}
TEST(CinnLaunchOpHelperTest, TestInitializeTempVar) {
// build test data
Scope temp_scope;
platform::CPUPlace place;
CinnScope compiled_scope;
compiled_scope.Var<CinnTensor>("temp_var1");
compiled_scope.Var<CinnTensor>("temp_var2");
compiled_scope.Var<CinnTensor>("var3");
auto compiled_tensor1 = compiled_scope.GetTensor("temp_var1");
compiled_tensor1->Resize(CinnShape({2, 3}));
auto compiled_tensor2 = compiled_scope.GetTensor("temp_var2");
compiled_tensor2->Resize(CinnShape({5, 6, 7}));
auto compiled_tensor3 = compiled_scope.GetTensor("var3");
compiled_tensor3->Resize(CinnShape({10}));
// expected result
InitializeTempVar({"temp_var1", "temp_var2"}, compiled_scope, place,
&temp_scope);
ASSERT_EQ(temp_scope.LocalVarNames().size(), 2);
auto* temp_var1 = temp_scope.FindVar("temp_var1");
ASSERT_NE(temp_var1, nullptr);
EXPECT_TRUE(temp_var1->IsType<LoDTensor>());
EXPECT_TRUE(temp_var1->Get<LoDTensor>().IsInitialized());
EXPECT_EQ(temp_var1->Get<LoDTensor>().dims(), framework::make_ddim({2, 3}));
auto* temp_var2 = temp_scope.FindVar("temp_var2");
ASSERT_NE(temp_var2, nullptr);
EXPECT_TRUE(temp_var2->IsType<LoDTensor>());
EXPECT_TRUE(temp_var2->Get<LoDTensor>().IsInitialized());
EXPECT_EQ(temp_var2->Get<LoDTensor>().dims(),
framework::make_ddim({5, 6, 7}));
}
TEST(CinnLaunchOpHelperTest, TestSharePaddleTensorWithCinnBuffer) {
// build test data
Scope scope;
platform::CPUPlace place;
auto* var1 = scope.Var("var1");
auto* tensor1 = var1->GetMutable<LoDTensor>();
tensor1->mutable_data<float>(framework::make_ddim({5, 6}), place);
auto* data1 = tensor1->data<float>();
data1[0] = 9.99;
data1[10] = 19.99;
ASSERT_EQ(tensor1->numel(), 30);
ASSERT_EQ(tensor1->dims().size(), 2);
// excepted result
cinn_buffer_t cinn_buffer;
SharePaddleTensorWithCinnBuffer(tensor1, &cinn_buffer);
ASSERT_NE(cinn_buffer.memory, nullptr);
ASSERT_EQ(cinn_buffer.num_elements(), 30);
auto* shadow_data = reinterpret_cast<float*>(cinn_buffer.memory);
EXPECT_FLOAT_EQ(shadow_data[0], 9.99);
EXPECT_FLOAT_EQ(shadow_data[10], 19.99);
}
TEST(CinnLaunchOpHelperTest, TestAppendExecutionArguments) {
// build test data
Scope scope;
platform::CPUPlace place;
auto* var1 = scope.Var("var1");
auto* tensor1 = var1->GetMutable<LoDTensor>();
tensor1->mutable_data<float>(framework::make_ddim({5, 6}), place);
auto* var2 = scope.Var("temp_var2");
auto* tensor2 = var2->GetMutable<LoDTensor>();
tensor2->mutable_data<float>(framework::make_ddim({10}), place);
// expected result
std::map<std::string, cinn_pod_value_t> name2argument;
std::vector<std::unique_ptr<cinn_buffer_t>> hold_buffers;
AppendExecutionArguments(scope, {"var1", "temp_var2"},
{{"var1", "cinn_var1"}}, &name2argument,
&hold_buffers);
ASSERT_EQ(name2argument.size(), 2);
ASSERT_EQ(hold_buffers.size(), 2);
EXPECT_NE(name2argument.count("cinn_var1"), 0);
EXPECT_NE(name2argument.count("temp_var2"), 0);
EXPECT_EQ(static_cast<cinn_buffer_t*>(name2argument.at("cinn_var1")),
hold_buffers.front().get());
EXPECT_EQ(static_cast<cinn_buffer_t*>(name2argument.at("temp_var2")),
hold_buffers.back().get());
}
} // namespace details
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdlib.h>
#include <random>
#include <string>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/init.h"
USE_OP(cinn_launch);
USE_OP(elementwise_add);
namespace paddle {
namespace operators {
using framework::LoDTensor;
using framework::ir::Graph;
using framework::ir::Node;
using framework::paddle2cinn::CinnCompiler;
std::unique_ptr<Graph> CreateOnlyElementwiseAddGraph(
const std::string& x_name, const std::string& y_name,
const std::string& out_name) {
auto g = std::make_unique<Graph>(framework::ProgramDesc());
framework::OpDesc feed_op_x, feed_op_y;
feed_op_x.SetType("feed");
feed_op_x.SetOutput("Out", {x_name});
feed_op_y.SetType("feed");
feed_op_y.SetOutput("Out", {y_name});
framework::VarDesc x_var(x_name);
framework::VarDesc y_var(y_name);
framework::VarDesc out_var(out_name);
framework::OpDesc elementwise_add_op;
elementwise_add_op.SetType("add");
elementwise_add_op.SetInput("X", {x_name});
elementwise_add_op.SetInput("Y", {y_name});
elementwise_add_op.SetOutput("Out", {out_name});
auto* feed_op_node_x = g->CreateOpNode(&feed_op_x);
auto* feed_op_node_y = g->CreateOpNode(&feed_op_y);
auto* elementwise_add_node = g->CreateOpNode(&elementwise_add_op);
auto* x_node = g->CreateVarNode(&x_var);
auto* y_node = g->CreateVarNode(&y_var);
auto* out_node = g->CreateVarNode(&out_var);
// fill op node
feed_op_node_x->outputs = {x_node};
feed_op_node_y->outputs = {y_node};
elementwise_add_node->inputs = {x_node, y_node};
elementwise_add_node->outputs = {out_node};
// fill variable node
x_node->inputs = {feed_op_node_x};
x_node->outputs = {elementwise_add_node};
y_node->inputs = {feed_op_node_y};
y_node->outputs = {elementwise_add_node};
out_node->inputs = {elementwise_add_node};
return g;
}
void CreateInputVariablesWithRandomData(
const std::vector<std::string>& variable_names,
const framework::DDim& common_ddim, framework::Scope* scope) {
std::random_device seed;
std::default_random_engine engine(seed());
std::uniform_real_distribution<float> dist(0.f, 2.f);
for (const auto& var_name : variable_names) {
auto* tensor = scope->Var(var_name)->GetMutable<LoDTensor>();
auto* data = tensor->mutable_data<float>(common_ddim, platform::CPUPlace());
for (auto i = 0; i < tensor->numel(); ++i) {
data[i] = dist(engine);
}
}
}
void CopyInputDataToPlace(const framework::Scope& scope,
const platform::Place& dst_place,
framework::Scope* dst_scope) {
for (const auto& var_name : scope.LocalVarNames()) {
const auto& src_tensor = scope.GetVar(var_name)->Get<LoDTensor>();
auto* dst_tensor = dst_scope->Var(var_name)->GetMutable<LoDTensor>();
TensorCopySync(src_tensor, dst_place, dst_tensor);
}
}
TEST(CinnLaunchOpTest, TestElementwiseAddPass) {
paddle::framework::InitDevices();
platform::SetNumThreads(1);
// cache test graph into CinnCompiler
const auto& test_out_name = "test_out";
const auto& expected_out_name = "expected_out";
auto compilation_key = CinnCompiler::GetInstance()->AddGraph(
CreateOnlyElementwiseAddGraph("test_x", "test_y", test_out_name));
// create cinn_launch_op and elementwise_add op
auto cinn_launch_op = paddle::framework::OpRegistry::CreateOp(
"cinn_launch", {{"X", {"test_x", "test_y"}}}, {{"Out", {test_out_name}}},
{{"compilation_key", compilation_key}});
auto elementwise_add_op = paddle::framework::OpRegistry::CreateOp(
"elementwise_add", {{"X", {"test_x"}}, {"Y", {"test_y"}}},
{{"Out", {expected_out_name}}}, {{}});
// prepare input data
framework::Scope init_scope;
CreateInputVariablesWithRandomData({"test_x", "test_y"}, {10, 20},
&init_scope);
// Run ops and check the computation results
auto run_and_check_fn = [&](const platform::Place& place) {
framework::Scope scope;
CopyInputDataToPlace(init_scope, place, &scope);
scope.Var(test_out_name)->GetMutable<LoDTensor>();
scope.Var(expected_out_name)->GetMutable<LoDTensor>();
cinn_launch_op->Run(scope, place);
elementwise_add_op->Run(scope, place);
LoDTensor test_out, expected_out;
if (platform::is_cpu_place(place)) {
test_out.ShareDataWith(scope.Var(test_out_name)->Get<LoDTensor>());
expected_out.ShareDataWith(
scope.Var(expected_out_name)->Get<LoDTensor>());
} else {
TensorCopySync(scope.Var(test_out_name)->Get<LoDTensor>(),
platform::CPUPlace(), &test_out);
TensorCopySync(scope.Var(expected_out_name)->Get<LoDTensor>(),
platform::CPUPlace(), &expected_out);
}
ASSERT_TRUE(test_out.IsInitialized());
ASSERT_TRUE(expected_out.IsInitialized());
ASSERT_EQ(test_out.dims(), expected_out.dims());
const auto* test_data = test_out.data<float>();
const auto* excepted_data = expected_out.data<float>();
for (auto i = 0; i < expected_out.numel(); ++i) {
EXPECT_FLOAT_EQ(test_data[i], excepted_data[i]);
}
};
LOG(INFO) << "Check compute result on cpu";
run_and_check_fn(platform::CPUPlace());
run_and_check_fn(platform::CPUPlace());
// create an new elementwise_add op
// because the above one cached the cpu kernel
LOG(INFO) << "Check compute result on gpu";
cinn_launch_op = paddle::framework::OpRegistry::CreateOp(
"cinn_launch", {{"X", {"test_x", "test_y"}}}, {{"Out", {test_out_name}}},
{{"compilation_key", compilation_key}});
elementwise_add_op = paddle::framework::OpRegistry::CreateOp(
"elementwise_add", {{"X", {"test_x"}}, {"Y", {"test_y"}}},
{{"Out", {expected_out_name}}}, {{}});
run_and_check_fn(platform::CUDAPlace());
run_and_check_fn(platform::CUDAPlace());
}
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册