Unverified commit 86bb6a01, authored by Aurelius84, committed by GitHub

[NewIR]Part-1: Add CinnJitInstruction for interpreter (#56302)

* [NewIR]Add CinnJitInstruction for interpreter

* fix windows compile error
Parent 4d501872
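In short, this part lets the NewIR interpreter execute CINN JIT-compiled kernels. A minimal sketch of the intended flow, assembled from the test jit_instruction_test.cc added below (the full listing follows later in this diff):

// Compile the ::ir::Program with CINN, then lower the result into the
// runtime dialect, where each kernel becomes a cinn::dialect::JitKernelOp
// holding a cinn::hlir::framework::Instruction.
auto target = cinn::common::DefaultNVGPUTarget();
auto scope = cinn::hlir::framework::BuildScope(target, *program);
cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope);
auto runtime_program = ir_compiler.Build();
std::unique_ptr<::ir::Program> ir_runtime_program =
    cinn::hlir::framework::ConvertToRuntimeDialect(*runtime_program);
// NewIRInterpreter::BuildInstruction() maps each op of the "cinn" dialect to
// the new CinnJitInstruction, so running the program executes the JIT kernels.
paddle::framework::Scope exe_scope;
paddle::framework::InterpreterCore executor(
    platform::CUDAPlace(0), {}, std::move(ir_runtime_program), &exe_scope);
executor.Run({});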
......@@ -24,9 +24,9 @@ gather_srcs(
# TODO(Aurelius84): new_ir_compiler depends on pd_dialect and cannot be
# found under CINN_ONLY mode
if(NOT CINN_ONLY)
cinn_cc_library(new_ir_compiler SRCS new_ir_compiler.cc DEPS cinncore
cinn_cc_library(new_ir_compiler SRCS new_ir_compiler.cc DEPS cinnapi
pd_dialect)
cinn_cc_library(convert_to_dialect SRCS convert_to_dialect.cc DEPS cinncore
cinn_cc_library(convert_to_dialect SRCS convert_to_dialect.cc DEPS cinnapi
cinn_dialect)
endif()
......
......@@ -130,10 +130,11 @@ class Instruction {
}
}
int size() { return fn_ptrs_.size(); }
int size() const { return fn_ptrs_.size(); }
std::string DumpInstruction() const;
const std::string& function_name() const { return function_name_; }
const std::vector<std::vector<std::string>>& GetInArgs() const {
return in_args_;
}
......
......@@ -15,7 +15,6 @@
#include "paddle/cinn/hlir/framework/new_ir_compiler.h"
#include <absl/types/variant.h>
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/op_strategy.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
......
......@@ -3,3 +3,10 @@ cc_library(
SRCS instruction_base.cc phi_kernel_instruction.cc
legacy_kernel_instruction.cc instruction_util.cc
DEPS phi framework_proto)
if(WITH_CINN AND NOT CINN_ONLY)
cc_library(
cinn_jit_instruction
SRCS cinn_jit_instruction.cc
DEPS phi cinnapi cinn_dialect)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h"
#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
#include "paddle/cinn/hlir/dialect/runtime_dialect.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/fluid/framework/paddle2cinn/transform_type.h"
namespace paddle {
namespace framework {
// TODO(Aurelius84): Think deeply about what its responsibility should be.
// Currently it assumes the CinnLaunchContext role.
class JitContext {
public:
cinn_buffer_t* GetCinnBufferOfVar(const std::string& name) {
auto res = paddle2argument_.find(name);
PADDLE_ENFORCE_NE(
res,
paddle2argument_.end(),
platform::errors::NotFound(
"Variable(%s) not found in compilation result", name));
return static_cast<cinn_buffer_t*>(res->second);
}
// NOTE(Aurelius84): Before running each instruction, we should share the
// Tensor memory from the paddle scope with the cinn_buffer_t from the cinn
// scope, including both inputs and outputs (see the sketch after this file).
void ShareMemToCinn(const std::string& var_name,
const phi::Place& place,
Scope* scope) {
cinn_buffer_t* buffer = GetCinnBufferOfVar(var_name);
auto* tensor = scope->GetVar(var_name)->GetMutable<phi::DenseTensor>();
// TODO(Aurelius84): Maybe we should consider unifying the Scope
// structure between paddle and cinn, so that we don't need to develop
// this glue code.
buffer->memory = reinterpret_cast<uint8_t*>(tensor->mutable_data(
place, paddle2cinn::TransToPaddleDataType(buffer->type)));
}
// TODO(Aurelius84): Add logic to parse the stream for different devices.
void* GetStream() { return nullptr; }
private:
// Because a cinn_pod_value_t does not own a cinn_buffer_t object, extra
// storage is necessary to keep those objects alive; they cannot be
// released until the runtime program finishes execution.
std::vector<std::unique_ptr<cinn_buffer_t>> hold_buffers_;
// This map saves all execution arguments keyed by their cinn names; it is
// passed to the Execute interface of a cinn runtime program.
std::map<std::string, cinn_pod_value_t> name2argument_;
// This map saves all execution arguments keyed by paddle variable names;
// it combines name2argument_ and paddle2cinn_varmap_.
std::map<std::string, cinn_pod_value_t> paddle2argument_;
};
// TODO(Aurelius84): Impl should hold a JitContext instance to
// deliver the device context for 'instr->Run' and be responsible
// for the inner buffer_t sharing between framework::Scope
// and cinn::Scope.
class CinnJitInstruction::Impl {
using Instruction = cinn::hlir::framework::Instruction;
public:
explicit Impl(Instruction* instr) : instr_(instr) {}
// TODO(Aurelius84): Support specifying the name2podargs and stream arguments.
void Run() {
PADDLE_ENFORCE_NOT_NULL(
instr_, platform::errors::NotFound("instr_ should not be NULL"));
instr_->Run(/*name2podargs=*/nullptr,
false,
/*stream=*/nullptr,
/*use_cache=*/true);
}
const Instruction* pointer() const { return instr_; }
private:
Instruction* instr_{nullptr};
};
CinnJitInstruction::CinnJitInstruction(size_t id,
const platform::Place& place,
::ir::Operation* op,
Scope* scope)
: InstructionBase(id, place) {
// TODO(Aurelius84): We should simplify the members of JitKernelOp so that it
// only holds the related function ptrs. Impl is the real runtime data
// structure responsible for constructing hlir::framework::Instruction.
auto jit_kernel_op = op->dyn_cast<cinn::dialect::JitKernelOp>();
impl_ = std::make_shared<Impl>(jit_kernel_op.instruction());
}
void CinnJitInstruction::Run() {
VLOG(6) << "Run cinn jit_kernel_op : " << Name();
impl_->Run();
}
const std::string& CinnJitInstruction::Name() const {
// TODO(Aurelius84): Consider the case of instructions containing
// multiple function ptrs and function names.
return impl_->pointer()->function_name();
}
} // namespace framework
} // namespace paddle
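For reference, a hedged sketch of how the JitContext defined above is presumably meant to be used once it is wired into CinnJitInstruction (in this part it is declared but not yet referenced); the helper name and argument list below are illustrative only:

// Illustrative helper, not part of this change.
void RunWithSharedBuffers(JitContext* ctx,
                          const std::vector<std::string>& var_names,
                          const phi::Place& place,
                          Scope* scope,
                          cinn::hlir::framework::Instruction* instr) {
  // Share each paddle tensor's allocation with the matching cinn_buffer_t so
  // that the JIT kernel reads and writes the interpreter's own memory.
  for (const auto& name : var_names) {
    ctx->ShareMemToCinn(name, place, scope);
  }
  // Mirror the call made by CinnJitInstruction::Impl::Run above.
  instr->Run(/*name2podargs=*/nullptr,
             false,
             /*stream=*/ctx->GetStream(),
             /*use_cache=*/true);
}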
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
namespace ir {
class Operation;
}
namespace paddle {
namespace framework {
class Scope;
class CinnJitInstruction : public InstructionBase {
public:
CinnJitInstruction(size_t id,
const platform::Place& place,
::ir::Operation* op,
Scope* scope);
// TODO(Aurelius84): Only the core interface is implemented; GC and Event
// logic still need to be implemented.
void Run() override;
const std::string& Name() const override;
private:
class Impl;
std::shared_ptr<Impl> impl_{nullptr};
};
} // namespace framework
} // namespace paddle
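For context, a hedged sketch of how this class is driven. In this PR the construction happens inside NewIRInterpreter::BuildInstruction() (shown further below), so the standalone call here is illustrative only and assumes op is a "cinn"-dialect operation produced by ConvertToRuntimeDialect:

auto instr = std::make_unique<paddle::framework::CinnJitInstruction>(
    /*id=*/0, platform::CUDAPlace(0), op, scope);
instr->Run();               // executes the JIT-compiled CINN kernel
VLOG(6) << instr->Name();   // reports the underlying cinn function name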
......@@ -34,6 +34,10 @@ set(INTERPRETER_DEPS
${DEVICE_EVENT_LIBS}
glog)
if(WITH_CINN AND NOT CINN_ONLY)
set(INTERPRETER_DEPS ${INTERPRETER_DEPS} cinn_jit_instruction)
endif()
cc_library(
interpreter
SRCS ${INTERPRETER_SRCS}
......
......@@ -36,6 +36,9 @@
#include "paddle/fluid/platform/flags.h"
#include "paddle/phi/backends/device_manager.h"
#ifdef PADDLE_WITH_CINN
#include "paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h"
#endif
#include "paddle/fluid/framework/new_executor/instruction/legacy_kernel_instruction.h"
#include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h"
#include "paddle/fluid/ir/dialect/utils.h"
......@@ -450,9 +453,14 @@ void NewIRInterpreter::BuildInstruction() {
var_name_2_id_,
variable_2_var_name_));
}
#ifdef PADDLE_WITH_CINN
} else if (op->dialect()->name() == "cinn") {
vec_instruction_base_.emplace_back(
std::make_unique<CinnJitInstruction>(op_idx++, place_, op, scope_));
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Now only support pd or pd_kernel dialect."));
"Now only support pd_kernel and cinn dialect."));
}
}
}
......
......@@ -11,4 +11,14 @@ if(WITH_TESTING AND WITH_CINN)
gtest
glog)
set_tests_properties(test_new_ir_compiler PROPERTIES LABELS "RUN_TYPE=CINN")
cc_test_old(
test_jit_instruction
SRCS
jit_instruction_test.cc
DEPS
interpreter
new_ir_compiler
convert_to_dialect)
set_tests_properties(test_jit_instruction PROPERTIES LABELS "RUN_TYPE=CINN")
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
#include "paddle/cinn/hlir/dialect/runtime_dialect.h"
#include "paddle/cinn/hlir/framework/convert_to_dialect.h"
#include "paddle/cinn/hlir/framework/new_ir_compiler.h"
#include "paddle/cinn/utils/data_util.h"
std::unique_ptr<::ir::Program> BuildProgram() {
::ir::IrContext* ctx = ::ir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
auto program = std::make_unique<::ir::Program>(ctx);
::ir::Builder builder = ::ir::Builder(ctx, program->block());
const float value = 2.0;
auto full_op_x =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{64, 128},
value,
phi::DataType::FLOAT32,
phi::GPUPlace());
auto full_op_y =
builder.Build<paddle::dialect::FullOp>(std::vector<int64_t>{128, 64},
value,
phi::DataType::FLOAT32,
phi::GPUPlace());
return std::move(program);
}
namespace paddle {
namespace framework {
TEST(CinnJitInstruction, Run) {
// Step 1: Construct ir::Program
std::unique_ptr<::ir::Program> program = BuildProgram();
EXPECT_EQ(program->block()->size(), 2u);
// Step 2: Compile the new ir::Program into a runtime Program
auto target = cinn::common::DefaultNVGPUTarget();
auto scope = cinn::hlir::framework::BuildScope(target, *program);
ASSERT_EQ(scope->var_names().size(), 2);
cinn::hlir::framework::NewIRCompiler ir_compiler(*program, target, scope);
auto runtime_program = ir_compiler.Build();
// Step 3: Convert into cinn::dialect::RuntimeDialect
std::unique_ptr<::ir::Program> ir_runtime_program =
cinn::hlir::framework::ConvertToRuntimeDialect(*runtime_program);
std::set<std::string> out_names;
for (auto& var_name : scope->var_names()) {
std::string name = {var_name.begin(), var_name.end()};
out_names.insert(name);
}
platform::Place place = platform::CUDAPlace(0);
Scope exe_scope;
InterpreterCore executor(
place, {}, std::move(ir_runtime_program), &exe_scope);
executor.SetSkipGcVars(out_names);
executor.Run({});
// TODO(Aurelius84): Need to replace this check with one based on
// framework::Scope (see the sketch after this test).
const float value = 2.0;
for (auto& name : out_names) {
std::vector<float> data =
cinn::GetTensorData<float>(scope->GetTensor(name), target);
for (int i = 0; i < data.size(); ++i) {
LOG_FIRST_N(INFO, 3) << "data: " << data[i];
ASSERT_NEAR(data[i], value, 1e-5);
}
}
}
} // namespace framework
} // namespace paddle
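A hedged sketch of the replacement the TODO above refers to: read outputs back through the interpreter's framework::Scope instead of the cinn scope. How output variables are named inside exe_scope is an assumption here, and GPU tensors need a device-to-host copy (framework::TensorCopySync from tensor_util.h) before element-wise checks:

for (const auto& name : out_names) {
  auto* var = exe_scope.FindVar(name);
  if (var == nullptr) continue;  // name mapping between the scopes is still open
  const auto& gpu_tensor = var->Get<phi::DenseTensor>();
  phi::DenseTensor cpu_tensor;
  paddle::framework::TensorCopySync(gpu_tensor, phi::CPUPlace(), &cpu_tensor);
  for (int64_t i = 0; i < cpu_tensor.numel(); ++i) {
    ASSERT_NEAR(cpu_tensor.data<float>()[i], value, 1e-5);
  }
}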
......@@ -24,8 +24,6 @@
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/utils/data_util.h"
#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
......