Unverified commit 9d0baeab authored by TeFeng Chen, committed by GitHub

Add cinn_instruction_run_op for launching execution of a cinn instruction (#39435)

* add cinn_instruction_run_op for launching execution of a cinn instruction

* fix multi definition compilation error

* update cmake

* fix bug at infershape

* fix compile error due to lacking header file
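
In outline, the new operator looks up a compilation result cached in CinnCompiler by an integer index and runs a single instruction from its runtime program. A condensed sketch of that call path, mirroring CinnInstructionRunOpKernel::Compute further down in this diff (error checks and stream selection omitted):

    // Sketch only: summarizes the execution path added by this commit.
    void RunOneCinnInstruction(const paddle::framework::ExecutionContext& ctx) {
      auto cached_index = ctx.Attr<int64_t>("cached_index");
      auto ins_index = ctx.Attr<int64_t>("instruction_index");
      const auto& compiled =
          paddle::framework::paddle2cinn::CinnCompiler::GetInstance()
              ->GetCompiledObject(cached_index);
      const auto& instructions = compiled.runtime_program->GetRunInstructions();
      // input/output buffers are shared with CINN before this call
      instructions.at(ins_index)->Run(&compiled.launch_context->FinalizeArguments(),
                                      false, /*stream=*/nullptr);
    }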
Parent cf8a5573
@@ -88,7 +88,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
if (cache_by_struct_.count(cur_key_by_struct) != 0) {
exist = true;
cache_by_address_[cur_key_by_address] =
-    cache_by_struct_.at(cur_key_by_struct).get();
+    cache_by_struct_.at(cur_key_by_struct);
}
}
}
@@ -98,12 +98,13 @@ const CinnCompiledObject& CinnCompiler::Compile(
CompileGraph(graph, input_tensors, target, compiled_num, stream);
pten::AutoWRLock w_guard{&rwlock_};
if (!cache_by_struct_.count(cur_key_by_struct)) {
-  cache_by_address_[cur_key_by_address] = compiled_res.get();
-  cache_by_struct_[cur_key_by_struct] = std::move(compiled_res);
+  cache_by_address_[cur_key_by_address] = compiled_num;
+  cache_by_struct_[cur_key_by_struct] = compiled_num;
+  index2cache_.emplace(compiled_num, std::move(compiled_res));
}
}
pten::AutoRDLock guard{&rwlock_};
-const auto& cached_obj = *cache_by_address_[cur_key_by_address];
+const auto& cached_obj = *index2cache_[cache_by_address_[cur_key_by_address]];
return cached_obj;
}
@@ -115,6 +116,15 @@ const CinnCompiledObject& CinnCompiler::Compile(
return Compile(graph, input_tensors, target, stream);
}
const CinnCompiledObject& CinnCompiler::GetCompiledObject(
int64_t cached_index) const {
auto res = index2cache_.find(cached_index);
PADDLE_ENFORCE_NE(res, index2cache_.end(),
platform::errors::InvalidArgument(
"Index(%ld) not found in cache", cached_index));
return *res->second;
}
std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
std::string graph_key;
ProgramDesc program;
@@ -202,6 +212,7 @@ void CinnCompiler::Clear() {
graphs_.clear();
cache_by_address_.clear();
cache_by_struct_.clear();
index2cache_.clear();
}
real_compiled_num_.store(0);
}
@@ -240,6 +251,7 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
compiled_obj->launch_context =
std::make_unique<operators::details::CinnLaunchContext>(
compiled_obj->paddle2cinn_varmap, compiled_obj->scope);
compiled_obj->cached_index = compiled_num;
return compiled_obj;
}
......
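
The net effect of the hunks above: both key-indexed caches now store a plain int64_t index, and a single map owns the compiled objects. A standalone sketch of the resulting layout, with the cache key types simplified to std::string for illustration (an assumption; the real keys are CinnCacheKeyByAddress/CinnCacheKeyByStructure):

    #include <cstdint>
    #include <memory>
    #include <string>
    #include <unordered_map>

    struct CompiledObject { std::int64_t cached_index = 0; /* ... */ };

    std::unordered_map<std::string, std::int64_t> cache_by_address;  // simplified key
    std::unordered_map<std::string, std::int64_t> cache_by_struct;   // simplified key
    std::unordered_map<std::int64_t, std::unique_ptr<CompiledObject>> index2cache;

    const CompiledObject& GetCompiledObject(std::int64_t index) {
      // the real implementation guards this lookup with PADDLE_ENFORCE_NE
      return *index2cache.at(index);
    }

Any cached result is thus addressable by a stable integer, which is exactly what the new cinn_instruction_run op records in its cached_index attribute.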
@@ -53,6 +53,7 @@ struct CinnCompiledObject {
std::shared_ptr<::cinn::hlir::framework::Scope> scope;
std::unordered_map<std::string, std::string> paddle2cinn_varmap;
std::unique_ptr<operators::details::CinnLaunchContext> launch_context;
std::int64_t cached_index;
};
// Entrance to use CINN.
@@ -76,6 +77,8 @@ class CinnCompiler {
const std::map<std::string, const LoDTensor*>& input_tensors,
const ::cinn::common::Target& target, void* stream = nullptr);
const CinnCompiledObject& GetCompiledObject(int64_t cached_index) const;
std::string AddGraph(std::unique_ptr<ir::Graph> graph);
const ir::Graph& FindGraph(const std::string& graph_key) const;
@@ -101,12 +104,12 @@ class CinnCompiler {
void* stream = nullptr) const;
std::unordered_map<std::string, std::unique_ptr<ir::Graph>> graphs_;
-std::unordered_map<CinnCacheKeyByAddress, CinnCompiledObject*,
-                   CinnCacheKey::Hash>
+std::unordered_map<CinnCacheKeyByAddress, std::int64_t, CinnCacheKey::Hash>
cache_by_address_;
-std::unordered_map<CinnCacheKeyByStructure,
-                   std::unique_ptr<CinnCompiledObject>, CinnCacheKey::Hash>
+std::unordered_map<CinnCacheKeyByStructure, std::int64_t, CinnCacheKey::Hash>
cache_by_struct_;
std::unordered_map<std::int64_t, std::unique_ptr<CinnCompiledObject>>
index2cache_;
std::atomic_int64_t real_compiled_num_{0};
mutable pten::RWLock rwlock_;
......
@@ -270,13 +270,20 @@ TEST(CinnCompilerTest, Compile) {
auto compile_fn = [&](const Target& target) {
const auto& compiled_obj =
cinn_compiler->Compile(compiling_graph, input_tensors, target);
ASSERT_NE(compiled_obj.compiler, nullptr);
ASSERT_NE(compiled_obj.runtime_program, nullptr);
ASSERT_NE(compiled_obj.scope, nullptr);
ASSERT_FALSE(compiled_obj.paddle2cinn_varmap.empty());
ASSERT_NE(compiled_obj.launch_context, nullptr);
const auto& cached_obj =
cinn_compiler->Compile(compilation_key, input_tensors, target);
ASSERT_EQ(reinterpret_cast<std::uint64_t>(&compiled_obj),
reinterpret_cast<std::uint64_t>(&cached_obj));
ASSERT_EQ(cached_obj.cached_index + 1, cinn_compiler->real_compiled_num());
const auto& ret_obj =
cinn_compiler->GetCompiledObject(cached_obj.cached_index);
ASSERT_EQ(reinterpret_cast<std::uint64_t>(&compiled_obj),
reinterpret_cast<std::uint64_t>(&ret_obj));
};
// GPU Compilation
......
include(operators)
-register_operators(EXCLUDES cinn_launch_op)
+cc_library(cinn_op_helper SRCS cinn_op_helper.cc DEPS operator device_context)
cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope cinn)
-op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS string_helper cinn cinn_compiler cinn_launch_context)
+SET(CINN_OP_DEPS string_helper cinn cinn_compiler cinn_op_helper cinn_launch_context)
+register_operators(DEPS ${CINN_OP_DEPS})
if (WITH_TESTING)
cc_test(cinn_launch_context_test SRCS cinn_launch_context_test.cc DEPS ddim lod_tensor scope cinn_launch_context)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/cinn/cinn_instruction_run_op.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle::operators {
class CinnInstructionRunOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnInstructionRun");
OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
"CinnInstructionRun");
const CinnCompiledObject& compiled_object =
CinnCompiler::GetInstance()->GetCompiledObject(
ctx->Attrs().Get<int64_t>(kCachedIndex));
details::CinnLaunchContext* launch_context =
compiled_object.launch_context.get();
std::vector<std::string> output_args = ctx->Outputs(kOutputs);
std::vector<framework::DDim> output_dims(output_args.size());
std::transform(output_args.begin(), output_args.end(), output_dims.begin(),
[launch_context](const std::string& var_name) {
cinn_buffer_t* buffer =
launch_context->GetCinnBufferOfVar(var_name);
return framework::DDim(buffer->dims, buffer->dimensions);
});
ctx->SetOutputsDim(kOutputs, output_dims);
}
};
class CinnInstructionRunOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(kX,
         "(vector<LoDTensor>) "
         "which are the input arguments of this cinn instruction")
    .AsDuplicable();
AddOutput(kOutputs,
          "(vector<LoDTensor>) "
          "which are the output arguments of this cinn instruction")
    .AsDuplicable();
AddAttr<int64_t>(
    kCachedIndex,
    "(int64_t) "
    "the stored index of the cached compilation result in CinnCompiler, "
    "which is used to fetch the CinnCompiledObject where this cinn "
    "instruction is included");
AddAttr<int64_t>(
    kInstructionIndex,
    "(int64_t) "
    "the index of this instruction in the cinn runtime program");
AddComment(R"DOC(
CinnInstructionRun Operator.
This operator launches the execution of a
CINN(https://github.com/PaddlePaddle/CINN/blob/develop/README.md) instruction.
Both the input and output of this operator are a set of variables
which are the input and output arguments of the bound cinn instruction respectively.
In addition, the attribute 'cached_index' must be set to locate the
CinnCompiledObject where the instruction is included, and 'instruction_index'
is used to fetch the instruction object from the compiled runtime program.
It accomplishes the execution of the instruction according to the following steps:
0. Set the shapes of the output variables in the InferShape function
   using the compilation result.
1. Fetch the cinn instruction bound to this operator by 'cached_index'
   and 'instruction_index' from CinnCompiler.
2. Prepare the input and output variables of the instruction in Paddle and share
   their buffers with CINN by setting the 'memory' field of the corresponding cinn_buffer_t.
3. Launch the CINN runtime to execute the instruction.
)DOC");
}
};
} // namespace paddle::operators
namespace ops = paddle::operators;
using CPUDeviceContext = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(
cinn_instruction_run, ops::CinnInstructionRunOp,
ops::CinnInstructionRunOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
cinn_instruction_run,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, bool>,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, int>,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, int64_t>,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, float>,
ops::CinnInstructionRunOpKernel<CPUDeviceContext, double>);
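
Given the Maker above, a pass that emits this operator would configure it roughly as follows. This is a hypothetical sketch: "x0"/"out0" are placeholder variable names, and compiled_obj is assumed to be the CinnCompiledObject returned by CinnCompiler; the emitting pass itself is not part of this commit:

    paddle::framework::OpDesc op_desc;
    op_desc.SetType("cinn_instruction_run");
    op_desc.SetInput("X", {"x0", "x1"});   // kX
    op_desc.SetOutput("Out", {"out0"});    // kOutputs
    // index of the CinnCompiledObject in CinnCompiler (see cached_index above)
    op_desc.SetAttr("cached_index", compiled_obj.cached_index);
    // position of the instruction within the compiled runtime program
    op_desc.SetAttr("instruction_index", static_cast<int64_t>(0));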
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cinn/cinn_instruction_run_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace ops = paddle::operators;
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(
cinn_instruction_run,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, bool>,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, int>,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, int64_t>,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, float>,
ops::CinnInstructionRunOpKernel<CUDADeviceContext, double>);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iterator>
#include <memory>
#include <string>
#include <vector>
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/instruction.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/operators/cinn/cinn_op_helper.h"
namespace paddle::operators {
using CinnInstruction = ::cinn::hlir::framework::Instruction;
using CinnCompiledObject = framework::paddle2cinn::CinnCompiledObject;
using CinnCompiler = framework::paddle2cinn::CinnCompiler;
template <typename DeviceContext, typename T>
class CinnInstructionRunOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// step 1: fetch the cinn instruction bound to this operator
auto cached_index = ctx.template Attr<int64_t>(kCachedIndex);
auto ins_index = ctx.template Attr<int64_t>(kInstructionIndex);
const CinnCompiledObject& compiled_object =
CinnCompiler::GetInstance()->GetCompiledObject(cached_index);
const std::vector<std::unique_ptr<CinnInstruction>>& instructions =
compiled_object.runtime_program->GetRunInstructions();
PADDLE_ENFORCE_LT(ins_index, instructions.size(),
platform::errors::InvalidArgument(
"Index(%ld) > instructions.size(%ld).", ins_index,
instructions.size()));
auto&& instruction = instructions.at(ins_index);
// step 2: prepare the input and output arguments of the instruction
details::CinnLaunchContext* launch_context =
compiled_object.launch_context.get();
auto share_argument_buffer_fn = [launch_context,
&ctx](const std::string& var_name) {
cinn_buffer_t* buffer = launch_context->GetCinnBufferOfVar(var_name);
framework::Variable* var = ctx.scope().GetVar(var_name);
auto* tensor = var->template GetMutable<framework::LoDTensor>();
buffer->memory =
reinterpret_cast<uint8_t*>(tensor->mutable_data<T>(ctx.GetPlace()));
};
std::vector<std::string> in_args = ctx.InputNames(kX);
std::for_each(in_args.begin(), in_args.end(), share_argument_buffer_fn);
std::vector<std::string> out_args = ctx.OutputNames(kOutputs);
std::for_each(out_args.begin(), out_args.end(), share_argument_buffer_fn);
// step 3: launch CINN runtime to execute the instruction
// TODO(CtfGo): simplify format of arguments package as a vector in CINN
// and update this usage call
instruction->Run(&launch_context->FinalizeArguments(), false,
details::GetStream<DeviceContext>(ctx));
}
};
} // namespace paddle::operators
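
Step 2 above hinges on one detail: the cinn_buffer_t does not own its memory here; its memory field is simply pointed at the Paddle tensor's allocation, so CINN reads and writes Paddle-managed memory in place. In isolation (a sketch, assuming a float tensor, a placeholder variable name "var", and launch_context/scope/place already in scope):

    // Sketch of the aliasing in step 2; "var" is a placeholder name.
    cinn_buffer_t* buffer = launch_context->GetCinnBufferOfVar("var");
    auto* tensor = scope.GetVar("var")->GetMutable<paddle::framework::LoDTensor>();
    buffer->memory =
        reinterpret_cast<uint8_t*>(tensor->mutable_data<float>(place));

Because no copy is made, the instruction's results land directly in the operator's output tensors.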
@@ -24,12 +24,31 @@ CinnLaunchContext::CinnLaunchContext(
const std::unordered_map<std::string, std::string>& paddle2cinn_varmap,
const std::shared_ptr<CinnScope>& cinn_scope)
: paddle2cinn_varmap_(paddle2cinn_varmap), cinn_scope_(cinn_scope) {
// collect the names of all variables used by cinn
auto var_names = cinn_scope_->var_names();
cinn_variable_names_.reserve(var_names.size());
std::transform(
var_names.begin(), var_names.end(),
std::inserter(cinn_variable_names_, cinn_variable_names_.end()),
[](const auto& name_view) { return std::string(name_view.data()); });
// build the variable name map of cinn2paddle
for (const auto& x : paddle2cinn_varmap_) {
auto res = cinn2paddle_varmap_.emplace(x.second, x.first);
PADDLE_ENFORCE_EQ(
res.second, true,
platform::errors::InvalidArgument(
"Cinn variable(%s) maps to more than one paddle variable(%s,%s)",
x.second, res.first->second, x.first));
}
// supplement the relations of the remaining variables that do not appear
// in the above map; they are internal variables, so we use the names
// assigned by the cinn compiler for them.
for (const auto& var_name : cinn_variable_names_) {
if (!cinn2paddle_varmap_.count(var_name)) {
cinn2paddle_varmap_.emplace(var_name, var_name);
paddle2cinn_varmap_.emplace(var_name, var_name);
}
}
}
void CinnLaunchContext::UpdateCapturedEnv(const framework::Scope& scope,
@@ -189,6 +208,20 @@ CinnLaunchContext::FinalizeArguments() const {
return name2argument_;
}
cinn_buffer_t* CinnLaunchContext::GetCinnBufferOfVar(
const std::string& paddle_var_name) {
auto res = paddle2cinn_varmap_.find(paddle_var_name);
PADDLE_ENFORCE_NE(
res, paddle2cinn_varmap_.end(),
platform::errors::InvalidArgument(
"Variable(%s) not found in compilation result", paddle_var_name));
auto it = name2argument_.find(res->second);
PADDLE_ENFORCE_NE(it, name2argument_.end(),
platform::errors::InvalidArgument(
"Argument(%s) not be initialized", res->second));
return static_cast<cinn_buffer_t*>(it->second);
}
} // namespace details
} // namespace operators
} // namespace paddle
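
The constructor above makes the name mapping total in both directions. A worked example with hypothetical names: if paddle2cinn_varmap arrives as {"x" -> "var_1"} and the cinn scope holds {"var_1", "var_4_tmp"}, then after construction cinn2paddle_varmap_ is {"var_1" -> "x", "var_4_tmp" -> "var_4_tmp"} and paddle2cinn_varmap_ gains the entry {"var_4_tmp" -> "var_4_tmp"}. GetCinnBufferOfVar can then resolve either a paddle name or an internal cinn name through the paddle2cinn map before looking up the execution argument.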
@@ -64,6 +64,8 @@ class CinnLaunchContext {
// Finalize all execution arguments and return them
const std::map<std::string, cinn_pod_value_t>& FinalizeArguments() const;
cinn_buffer_t* GetCinnBufferOfVar(const std::string& paddle_var_name);
private:
// Get CinnTensor with CINN variable name
CinnTensor GetCinnTensor(const std::string& var_name);
@@ -84,19 +86,22 @@ class CinnLaunchContext {
std::unique_ptr<framework::Scope> cached_temp_scope_ = nullptr;
// a variable name map from paddle to cinn
-const std::unordered_map<std::string, std::string>& paddle2cinn_varmap_;
+std::unordered_map<std::string, std::string> paddle2cinn_varmap_;
// a variable name map from cinn to paddle
std::unordered_map<std::string, std::string> cinn2paddle_varmap_;
// the variable scope of cinn
const std::shared_ptr<CinnScope> cinn_scope_;
-// all variables used by compiled executable program
+// all names of cinn variables used by compiled executable program
std::unordered_set<std::string> cinn_variable_names_;
// because a cinn_pod_value_t does not own the cinn_buffer_t object,
// an extra storage is necessary to keep the object and it can
-// not be released until runtime program finish execution.
+// not be released until the runtime program finishes execution.
std::vector<std::unique_ptr<cinn_buffer_t>> hold_buffers_;
-// name to execution argument
+// this map saves all execution arguments with their cinn names as keys,
+// and it is passed to the Execute interface of the cinn runtime program.
std::map<std::string, cinn_pod_value_t> name2argument_;
};
......
@@ -13,10 +13,11 @@
// limitations under the License.
#include "paddle/fluid/operators/cinn/cinn_launch_op.h"
#include <functional>
#include <vector>
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/runtime/cinn_runtime.h"
#include "cinn/runtime/flags.h"
#include "paddle/fluid/string/string_helper.h"
DECLARE_bool(cudnn_deterministic);
......
@@ -13,36 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cinn/cinn_launch_op.h"
#include <memory>
#include <vector>
#include "cinn/runtime/cinn_runtime.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif
namespace paddle {
namespace operators {
namespace details {
#ifdef PADDLE_WITH_CUDA
template <>
void* GetStream<platform::CUDADeviceContext>(
const framework::ExecutionContext& ctx) {
const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
return dev_ctx.stream();
}
#endif
} // namespace details
} // namespace operators
} // namespace paddle
/* see [Why use single type kernel] */
REGISTER_OP_CUDA_KERNEL(cinn_launch,
......
@@ -18,27 +18,18 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/scope.h"
#include "cinn/runtime/cinn_runtime.h"
#include "cinn/runtime/flags.h"
#include "cinn/common/target.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn/cinn_launch_context.h"
#include "paddle/fluid/operators/cinn/cinn_op_helper.h"
namespace paddle {
namespace operators {
constexpr char kX[] = "X";
constexpr char kNoNeedBufferX[] = "NoNeedBufferX";
constexpr char kOutputs[] = "Out";
constexpr char kCompilationKey[] = "compilation_key";
using LoDTensor = framework::LoDTensor;
using CinnTensor = ::cinn::hlir::framework::Tensor;
using CinnScope = ::cinn::hlir::framework::Scope;
using CinnCompiler = framework::paddle2cinn::CinnCompiler;
using CinnCompiledObject = framework::paddle2cinn::CinnCompiledObject;
@@ -57,17 +48,6 @@ void LaunchCinnExecution(const CinnCompiledObject& compiled_obj,
// Set cinn FLAGS (such as FLAGS_cinn_cudnn_deterministic) with paddle's FLAGS.
void SetCinnRuntimeFlags();
template <typename DeviceContext>
void* GetStream(const framework::ExecutionContext& ctx) {
return nullptr;
}
#ifdef PADDLE_WITH_CUDA
template <>
void* GetStream<platform::CUDADeviceContext>(
const framework::ExecutionContext& ctx);
#endif
} // namespace details
template <typename DeviceContext, typename T>
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/cinn/cinn_op_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle::operators::details {
#ifdef PADDLE_WITH_CUDA
template <>
void* GetStream<platform::CUDADeviceContext>(
const framework::ExecutionContext& ctx) {
const auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
return dev_ctx.stream();
}
#endif
} // namespace paddle::operators::details
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/operator.h"
// We define some common names or utility functions
// for operators related to cinn in this file
namespace paddle::operators {
// input params, output params and attributes
constexpr char kX[] = "X";
constexpr char kNoNeedBufferX[] = "NoNeedBufferX";
constexpr char kOutputs[] = "Out";
constexpr char kCompilationKey[] = "compilation_key";
constexpr char kCachedIndex[] = "cached_index";
constexpr char kInstructionIndex[] = "instruction_index";
// utility functions
namespace details {
template <typename DeviceContext>
void* GetStream(const framework::ExecutionContext& ctx) {
return nullptr;
}
#ifdef PADDLE_WITH_CUDA
template <>
void* GetStream<platform::CUDADeviceContext>(
const framework::ExecutionContext& ctx);
#endif
} // namespace details
} // namespace paddle::operators
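
The GetStream pair above follows a common template-specialization pattern: the primary template returns nullptr, and only the CUDA specialization (declared here under PADDLE_WITH_CUDA, defined in cinn_op_helper.cc) touches CUDADeviceContext. A usage sketch inside a device-templated kernel:

    // Inside an OpKernel::Compute templated on DeviceContext:
    // - CPUDeviceContext resolves to the primary template and yields nullptr
    // - CUDADeviceContext resolves to the specialization and yields dev_ctx.stream()
    void* stream = details::GetStream<DeviceContext>(ctx);

This keeps the common code path backend-agnostic while still letting the instruction run on the correct CUDA stream.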