未验证 提交 69e9f03e 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Reconstruct the Instruction for NewIrInterpreter (#55239)

* add inplace interface

* support inplace

* refine code

* fix bug

* fix bug

* refine code

* add file

* add interface

* refine code

* refine code

* add phi kernel instruction

* refine code

* add test

* delete unused code

* add test

* add test

* add deps

* delete unused code

* fix bug

* fix bug
上级 7e4290c5
add_subdirectory(garbage_collector)
add_subdirectory(instruction)
add_subdirectory(interpreter)
add_subdirectory(workqueue)
......@@ -14,6 +15,7 @@ set(STANDALONE_EXECUTOR_DEPS
pd_op_to_kernel_pass
phi_kernel_adaptor
program_translator
instruction_base
ir)
cc_library(
......
# Library target for the new-executor instruction classes: compiles
# instruction_base.cc and phi_kernel_instruction.cc and depends on the
# phi and framework_proto targets.
cc_library(
instruction_base
SRCS instruction_base.cc phi_kernel_instruction.cc
DEPS phi framework_proto)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
namespace paddle {
namespace framework {

// Initializes the instruction id, clears the artificial flag, derives the
// kernel type from `place` (CPU -> kCpuSync, supported heterogeneous place ->
// kGpuAsync, anything else raises a fatal error) and fetches the device
// context registered for `place` from the global pool.
InstructionBase::InstructionBase(size_t id, const platform::Place& place)
    : id_(id), is_artificial_(false) {
  if (!platform::is_cpu_place(place)) {
    PADDLE_ENFORCE_EQ(
        interpreter::IsSupportedHeterPlace(place),
        true,
        phi::errors::Fatal("Unsupported current place %s", place));
    type_ = OpFuncType::kGpuAsync;
  } else {
    type_ = OpFuncType::kCpuSync;
  }
  dev_ctx_ = platform::DeviceContextPool::Instance().Get(place);
}

// Returns the kernel type currently assigned to this instruction.
OpFuncType InstructionBase::KernelType() const {
  return type_;
}

// Returns the device context this instruction runs on.
const platform::DeviceContext& InstructionBase::DeviceContext() const {
  return *dev_ctx_;
}

// Records this instruction's event (if one was registered) on its device
// context. The profiler scope is opened whether or not an event exists.
void InstructionBase::RecordEvent(const Place& place) const {
  platform::RecordEvent profiler_scope(
      "RecordStreamEvent", platform::TracerEventType::UserDefined, 10);
  if (!event_to_record_) {
    return;
  }
  VLOG(6) << "Record event at instruction: " << id_;
  event_to_record_->event_->Record(dev_ctx_);
}

// Waits on every event registered through AddEventToWait(). Nothing is done
// when `place` is a CPU place.
void InstructionBase::WaitEvent(const Place& place) const {
  const bool on_cpu = platform::is_cpu_place(place);
  if (on_cpu) {
    return;
  }
  for (const EventInter& pending : events_to_wait_) {
    // One profiler scope per awaited event.
    platform::RecordEvent profiler_scope(
        "WaitStreamEvent", platform::TracerEventType::UserDefined, 10);
    VLOG(6) << "Wait instruction: " << pending.instr_id_
            << " 's event with waiter_type: " << pending.waiter_type_;
    pending.event_->Wait(pending.waiter_type_, dev_ctx_);
  }
}

// Appends a variable id to the list returned by GCCheckVars().
void InstructionBase::AddGCCheckVar(size_t id) {
  gc_check_vars_.push_back(id);
}

// Returns the variable ids registered through AddGCCheckVar().
const std::vector<size_t>& InstructionBase::GCCheckVars() const {
  return gc_check_vars_;
}

// Returns the recorded (input, output) inplace variable pairs.
const std::vector<std::pair<Variable*, Variable*>>&
InstructionBase::InplaceInfo() const {
  return vec_inplace_in_to_out_;
}

// Registers an (input, output) inplace variable pair.
void InstructionBase::AddInplace(Variable* in, Variable* out) {
  vec_inplace_in_to_out_.emplace_back(in, out);
}

// Drops every recorded inplace pair.
void InstructionBase::ClearInplace() {
  vec_inplace_in_to_out_.clear();
}

// Replaces the input name -> index-list map with a copy of `inputs`.
void InstructionBase::SetInputs(
    const std::map<std::string, std::vector<int>>& inputs) {
  input_index_ = inputs;
}

// Replaces the output name -> index-list map with a copy of `outputs`.
void InstructionBase::SetOutputs(
    const std::map<std::string, std::vector<int>>& outputs) {
  output_index_ = outputs;
}

}  // namespace framework
}  // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/platform/event.h"
namespace paddle {
namespace framework {
using SchedulingPriority = int64_t;
class InstructionBase {
public:
explicit InstructionBase(size_t id, const platform::Place& place);
virtual ~InstructionBase() = default;
size_t Id() const { return id_; }
bool IsArtificial() const { return is_artificial_; }
void SetArtificial(bool is_artificial) { is_artificial_ = is_artificial; }
OpFuncType KernelType() const;
void SetKernelType(OpFuncType type) { type_ = type; }
int GetStreamPriority() const { return scheduling_priority_; }
void SetStreamPriority(SchedulingPriority scheduling_priority) {
scheduling_priority_ = scheduling_priority;
}
SchedulingPriority GetSchedulingPriority() const {
return scheduling_priority_;
}
void SetSchedulingPriority(SchedulingPriority priority) {
scheduling_priority_ = priority;
}
const std::string& GetExecutionStream() const { return execution_stream_; }
void SetExecutionStream(const std::string& stream) {
execution_stream_ = stream;
}
const platform::DeviceContext& DeviceContext() const;
void SetDeviceContext(platform::DeviceContext* ctx) { dev_ctx_ = ctx; }
const std::vector<size_t>& NextInstrsInDifferenceThread() const {
return next_instrs_in_different_thread_;
}
void AddNextInstrInDifferentThread(size_t id) {
next_instrs_in_different_thread_.push_back(id);
}
const std::vector<size_t>& NextInstrsInSameThread() const {
return next_instrs_in_same_thread_;
}
void AddNextInstrInSameThread(size_t id) {
next_instrs_in_same_thread_.push_back(id);
}
const EventInter& EventToRecord() const { return *event_to_record_; }
void AddEventToRecord(std::shared_ptr<platform::DeviceEvent> event,
platform::DeviceType waiter_type) {
event_to_record_ = std::make_shared<EventInter>(id_, event, waiter_type);
}
const std::vector<EventInter>& EventsToWait() const {
return events_to_wait_;
}
void AddEventToWait(size_t instr_id,
std::shared_ptr<platform::DeviceEvent> event,
platform::DeviceType waiter_type) {
events_to_wait_.emplace_back(instr_id, event, waiter_type);
}
void RecordEvent(const Place& place) const;
void WaitEvent(const Place& place) const;
const std::vector<size_t>& GCCheckVars() const;
void AddGCCheckVar(size_t id);
const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;
void AddInplace(Variable* in, Variable* out);
void ClearInplace();
std::map<int, int>& GetMutableInplaceBackMap() { return inplace_back_map_; }
const std::map<int, int>& GetInplaceBackMap() { return inplace_back_map_; }
const std::map<std::string, std::vector<int>>& Inputs() const {
return input_index_;
}
std::map<std::string, std::vector<int>>& GetMutableInputs() {
return input_index_;
}
void SetInputs(const std::map<std::string, std::vector<int>>& inputs);
const std::map<std::string, std::vector<int>>& Outputs() const {
return output_index_;
}
std::map<std::string, std::vector<int>>& GetMutableOutputs() {
return output_index_;
}
void SetOutputs(const std::map<std::string, std::vector<int>>& outputs);
virtual void Run() = 0;
private:
size_t id_;
bool is_artificial_; // Instruction is artificial means that it is only used
// to assist scheduling and no need to be executed.
OpFuncType type_;
// dist attrs:lower value, higher priority
int stream_priority_{0};
SchedulingPriority scheduling_priority_{0};
std::string execution_stream_{kDefaultStream};
platform::DeviceContext* dev_ctx_; // not owned
std::vector<size_t> next_instrs_in_different_thread_;
std::vector<size_t> next_instrs_in_same_thread_;
std::shared_ptr<EventInter> event_to_record_;
std::vector<EventInter> events_to_wait_;
std::vector<size_t> gc_check_vars_;
std::vector<std::pair<Variable*, Variable*>>
vec_inplace_in_to_out_; // If not use share data, need this ?
std::map<int, int> inplace_back_map_;
std::map<std::string, std::vector<int>> input_index_;
std::map<std::string, std::vector<int>> output_index_;
};
} // namespace framework
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/interface/infermeta.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/core/type_defs.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/value.h"
namespace paddle {
namespace framework {
// Chooses how an op is dispatched on `place`:
//   - CPU place                      -> kCpuSync
//   - unsupported non-CPU place      -> fatal error
//   - selected host-heavy GPU ops    -> kGpuSync
//   - everything else                -> kGpuAsync
OpFuncType AnalyseOpFuncType(ir::Operation* op, const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    return OpFuncType::kCpuSync;
  }

  PADDLE_ENFORCE_EQ(interpreter::IsSupportedHeterPlace(place),
                    true,
                    phi::errors::Fatal("Unsupported current place %s", place));

  const auto attrs = op->attributes();
  const auto op_name = attrs.at("op_name").dyn_cast<::ir::StrAttribute>().data();

  // True when the named bool attribute of `op` holds false. The lookup is
  // only performed when the helper is actually invoked, so the short-circuit
  // evaluation below decides which attributes get queried.
  const auto flag_is_off = [op](const char* attr_name) {
    return !op->attribute<ir::BoolAttribute>(attr_name).data();
  };

  // Some GPU OPs do not launch CUDA Kernel, but spend a lot of time on CPU
  // computing. They execute serially in device thread and block CUDA kernel
  // launching in other GPU OPs. To improve performance, set them as kGpuSync
  // and so that they would be dispatched to host thread.
  const bool host_side_coalesce =
      op_name == kCoalesceTensor &&
      (!platform::is_xpu_place(place) || flag_is_off("persist_output")) &&
      flag_is_off("set_constant") && flag_is_off("copy_data");
  if (host_side_coalesce) {
    return OpFuncType::kGpuSync;
  }

  // A device-to-host memcpy explicitly requested by the user.
  const bool explicit_d2h_copy =
      platform::is_gpu_place(place) && op_name == interpreter::kMemcpyD2H;
  if (explicit_d2h_copy) {
    return OpFuncType::kGpuSync;
  }

  if (op_name == "shape") {
    return OpFuncType::kGpuSync;
  }

  return OpFuncType::kGpuAsync;
}
// Prepares everything Run() needs for one op of the kernel dialect:
//   1. reads the op name and looks up its registered OpInfo;
//   2. marks combine / feed / set_parameter / get_parameter ops as artificial
//      and stops there;
//   3. otherwise picks the kernel type, builds the infer-meta context,
//      selects the phi kernel named by the "kernel_name" / "kernel_key"
//      attributes, builds the kernel context (filling the input / output
//      index maps) and installs the device context of the kernel backend.
PhiKernelInstruction::PhiKernelInstruction(
    size_t id,
    const platform::Place& place,
    ir::Operation* op,
    Scope* scope,
    Scope* local_scope,
    const std::unordered_map<::ir::Value, std::string>& value_2_name_map)
    : InstructionBase(id, place) {
  const auto attrs = op->attributes();
  const auto op_name = attrs.at("op_name").dyn_cast<::ir::StrAttribute>().data();
  ir::OpInfo op_info = ir::IrContext::Instance()->GetRegisteredOpInfo(op_name);
  phi_op_name_ = op_name;

  const bool needs_no_kernel =
      op_name == "builtin.combine" || op_name == "pd.feed" ||
      op_name == "builtin.set_parameter" || op_name == "builtin.get_parameter";
  if (needs_no_kernel) {
    VLOG(6) << "skip process " << op_name;
    SetArtificial(true);
    return;
  }

  // Todo: support paddle::dialect::DistAttribute
  //   if (op_attributes.count("dist_attr") != 0) {
  //     if (op_attributes.count("execution_stream") != 0) {
  //         SetExecutionStream(op_attributes.at("execution_stream")
  //                             .dyn_cast<::ir::StrAttribute>()
  //                             .data());
  //     }
  //     if (op_attributes.count("stream_priority") != 0) {
  //         SetStreamPriority(op_attributes.at("stream_priority")
  //                             .dyn_cast<::ir::Int32Attribute>()
  //                             .data());
  //     }
  //     if (op_attributes.count("scheduling_priority") != 0) {
  //         SetSchedulingPriority(op_attributes.at("scheduling_priority")
  //                                 .dyn_cast<::ir::Int64Attribute>()
  //                                 .data());
  //     }
  //   } else {
  //     if (interpreter::IsCommunicationOp(op)) {
  //       // NOTE(Ruibiao): Dispatching computation before communication
  //       // improves multi-stream overlap when the time cost of
  //       // communication less than that of the calculation (e.g.,
  //       // ResNet50_bs128_pure_fp16 N4C32 training).
  //       op_func_node.scheduling_priority_ = 1;
  //     }
  //   }

  SetKernelType(AnalyseOpFuncType(op, place));

  // Infer-meta: keep the interface pointer and pre-build its context.
  infer_meta_interface_ =
      op_info.GetInterfaceImpl<paddle::dialect::InferMetaInterface>();
  auto yaml_interface =
      op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>();
  paddle::dialect::OpYamlInfoParser yaml_info_parser(
      yaml_interface->get_op_info_());

  ::ir::BuildPhiContext<
      phi::InferMetaContext,
      phi::MetaTensor,
      phi::MetaTensor,
      paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
      paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
      false>(op,
             value_2_name_map,
             scope,
             local_scope,
             yaml_info_parser,
             &infer_meta_context_);
  VLOG(6) << "finish process infer meta context";

  // Kernel selection: the instruction keeps a heap-allocated copy of the
  // kernel returned by the factory.
  const auto kernel_name =
      attrs.at("kernel_name").dyn_cast<ir::StrAttribute>().data();
  const auto kernel_key = attrs.at("kernel_key")
                              .dyn_cast<paddle::dialect::KernelAttribute>()
                              .data();
  auto selected = phi::KernelFactory::Instance().SelectKernelOrThrowError(
      kernel_name, kernel_key);
  phi_kernel_ = new phi::Kernel(selected.kernel);
  PADDLE_ENFORCE_EQ(
      phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name);
  VLOG(6) << "finish process select kernel";

  // Kernel context: also records the input / output indices on the base.
  ::ir::BuildPhiContext<phi::KernelContext,
                        const phi::TensorBase*,
                        phi::TensorBase*,
                        paddle::small_vector<const phi::TensorBase*>,
                        paddle::small_vector<phi::TensorBase*>,
                        true>(op,
                              value_2_name_map,
                              scope,
                              local_scope,
                              yaml_info_parser,
                              &kernel_context_,
                              &(GetMutableInputs()),
                              &(GetMutableOutputs()));

  // Both the kernel context and the instruction use the device context of
  // the backend named by the kernel key.
  auto* backend_dev_ctx = phi::DeviceContextPool::Instance().Get(
      phi::TransToPhiPlace(kernel_key.backend()));
  kernel_context_.SetDeviceContext(backend_dev_ctx);
  VLOG(6) << "finish process kernel context";

  SetDeviceContext(backend_dev_ctx);
  VLOG(6) << "finish process device context";
}
void PhiKernelInstruction::Run() {
VLOG(5) << "Run op " << phi_op_name_ << " infer meta.";
infer_meta_interface_->infer_meta_(&(infer_meta_context_));
VLOG(5) << "Run op " << phi_op_name_ << " kernel.";
(*(phi_kernel_))(&(kernel_context_));
}
} // namespace framework
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
namespace ir {
class Operation;
} // namespace ir
namespace paddle {
namespace framework {
class Scope;
class Value;
class PhiKernelInstruction : public InstructionBase {
public:
PhiKernelInstruction(
size_t id,
const platform::Place& place,
::ir::Operation* op,
Scope* scope,
Scope* local_scope,
const std::unordered_map<::ir::Value, std::string>& value_2_name_map);
const std::string& PhiOpName() const { return phi_op_name_; }
phi::Kernel* PhiKernel() const { return phi_kernel_; }
const phi::KernelContext& KernelContext() const { return kernel_context_; }
const phi::InferMetaContext& InferMetaContext() const {
return infer_meta_context_;
}
paddle::dialect::InferMetaInterface::Concept* InferMetaInterface() const {
return infer_meta_interface_;
}
void Run() override;
private:
std::string phi_op_name_;
paddle::dialect::InferMetaInterface::Concept* infer_meta_interface_{
nullptr}; // not owned
phi::InferMetaContext infer_meta_context_;
phi::KernelContext kernel_context_;
phi::Kernel* phi_kernel_{nullptr}; // not owned
};
} // namespace framework
} // namespace paddle
......@@ -72,6 +72,12 @@ class InterpreterBaseImpl {
virtual paddle::framework::FetchList Run(
const std::vector<std::string>& feed_names, bool need_fetch = true) = 0;
// NOTE(zhangbo): This interface is temporary. It exists only for testing
// while the executor is being adapted to the new IR, and it will be deleted
// in the future.
virtual paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names, bool need_fetch = true) = 0;
virtual void ShareWorkQueueFrom(InterpreterBaseImpl* src) = 0;
virtual void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) = 0;
......
......@@ -72,6 +72,11 @@ FetchList InterpreterCore::Run(const std::vector<std::string>& feed_names,
return impl_->Run(feed_names, need_fetch);
}
// Forwards the call to the underlying interpreter implementation and hands
// back whatever it fetched.
FetchList InterpreterCore::BetaRun(const std::vector<std::string>& feed_names,
                                   bool need_fetch) {
  FetchList fetched = impl_->BetaRun(feed_names, need_fetch);
  return fetched;
}
void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
impl_->ShareWorkQueueFrom(const_cast<InterpreterBaseImpl*>(src->Impl()));
}
......
......@@ -51,6 +51,9 @@ class InterpreterCore {
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true);
paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names, bool need_fetch = true);
void ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src);
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog);
......
......@@ -36,6 +36,7 @@
#include "paddle/fluid/platform/flags.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h"
#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h"
namespace paddle {
......@@ -231,6 +232,39 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
}
}
// Temporary entry point that executes the program through the
// InstructionBase pipeline.
//
// On the first call the scope is populated from the IR block and one
// instruction per op is built; every call then runs all instructions in
// order. `is_build_` is set once the first run has completed so that later
// calls reuse the built instructions instead of rebuilding the scope and the
// instruction list each time (previously the flag was never set here, which
// made the "already built" branch unreachable).
//
// NOTE(review): `is_build_` is presumably shared with Run(); confirm that
// Run() and BetaRun() are not mixed on the same interpreter.
//
// Returns the fetch list stored in the fetch variable of the inner scope when
// that variable exists and `need_fetch` is true, otherwise an empty list.
// `feed_names` is currently not used.
FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
                                    bool need_fetch) {
  SetDeviceId(place_);

  if (!is_build_) {
    LOG_FIRST_N(INFO, 1) << "New Executor is BetaRunning.";
    ::ir::BuildScope(
        *ir_program_->block(), scope_, local_scope_, &value_2_var_name_map_);
    BuildInstruction();
  }

  for (const auto& instruction : vec_instruction_base_) {
    instruction->Run();
  }

  // Only mark the interpreter as built after a complete first run.
  is_build_ = true;

  if (HasLocalScope()) {
    ClearLoDTensorArrayInLocalScope();
  }

  // return Fetch Tensors
  Scope* inner_scope = InnerScope();
  auto* fetch_var = inner_scope->FindVar(interpreter::kFetchVarName);
  if (fetch_var && need_fetch) {
    auto fetch_list = std::move(*fetch_var->GetMutable<framework::FetchList>());
    return fetch_list;
  } else {
    return {};
  }
}
void NewIRInterpreter::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
copy_program_ = prog;
}
......@@ -1479,5 +1513,32 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace() {
trace_execute_order_ = trace_order;
}
/// ======================== ///
/// For new ir ///
/// ======================== ///
// Rebuilds vec_instruction_base_ from the ops of the program's block. Every
// op must belong to the "pd_kernel" dialect and becomes one
// PhiKernelInstruction; any other dialect raises an Unimplemented error.
// Instruction ids are assigned consecutively starting from 0.
void NewIRInterpreter::BuildInstruction() {
  VLOG(0) << "Build Instructions for new ir ... ";
  vec_instruction_base_.clear();

  size_t next_id = 0;
  for (auto* op : *ir_program_->block()) {
    VLOG(0) << "Build Instruction for op: " << next_id;

    if (op->dialect()->name() != "pd_kernel") {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Now only support pd_kernel dialect."));
    }

    const size_t instr_id = next_id++;
    vec_instruction_base_.emplace_back(
        std::make_unique<PhiKernelInstruction>(instr_id,
                                               place_,
                                               op,
                                               scope_,
                                               local_scope_,
                                               value_2_var_name_map_));
  }
}
} // namespace framework
} // namespace paddle
......@@ -14,6 +14,7 @@
#pragma once
#include <memory>
#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
#include "paddle/fluid/framework/new_executor/interpreter_base_impl.h"
namespace ir {
......@@ -46,6 +47,10 @@ class NewIRInterpreter : public InterpreterBaseImpl {
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true) override;
paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names,
bool need_fetch = true) override;
void ShareWorkQueueFrom(InterpreterBaseImpl* src) override;
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) override;
......@@ -178,8 +183,15 @@ class NewIRInterpreter : public InterpreterBaseImpl {
std::vector<HookFunc> hookfuncs_;
/// ======================== ///
/// For new ir ///
/// ======================== ///
void BuildInstruction();
std::unique_ptr<::ir::Program> ir_program_{nullptr};
std::vector<std::unique_ptr<InstructionBase>> vec_instruction_base_;
std::unordered_map<::ir::Value, std::string> value_2_var_name_map_;
};
......
......@@ -219,6 +219,11 @@ FetchList ProgramInterpreter::Run(const std::vector<std::string>& feed_names,
}
}
// Placeholder implementation: ignores both arguments and always returns an
// empty fetch list.
FetchList ProgramInterpreter::BetaRun(
    const std::vector<std::string>& feed_names, bool need_fetch) {
  return FetchList{};
}
void ProgramInterpreter::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
copy_program_ = prog;
}
......
......@@ -48,6 +48,10 @@ class ProgramInterpreter : public InterpreterBaseImpl {
paddle::framework::FetchList Run(const std::vector<std::string>& feed_names,
bool need_fetch = true) override;
paddle::framework::FetchList BetaRun(
const std::vector<std::string>& feed_names,
bool need_fetch = true) override;
void ShareWorkQueueFrom(InterpreterBaseImpl* src) override;
void SetCopyProgram(std::shared_ptr<ProgramDesc> prog) override;
......
......@@ -18,17 +18,17 @@ OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand({input_i
"""
OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return result({output_index}); }}
"""
OP_GET_ATTRIBUTE_TEMPLATE = """ ir::Attribute attribute(const std::string &name) {{
OP_GET_ATTRIBUTE_TEMPLATE = """ ir::Attribute attribute(const std::string &name) {
PADDLE_ENFORCE(attributes().count(name) > 0,
phi::errors::PreconditionNotMet("Attribute is not exist."));
return attributes().at(name);
}}
}
template <typename T>
T attribute(const std::string &name) {{
T attribute(const std::string &name) {
PADDLE_ENFORCE(attributes().count(name) > 0 && attributes().at(name).isa<T>(),
phi::errors::PreconditionNotMet("Attribute is not right."));
return attributes().at(name).dyn_cast<T>();
}}
}
"""
......
......@@ -329,10 +329,10 @@ void BuildPhiContext(
ctx->EmplaceBackOutput(out_ptr);
} else if (out_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::DenseTensor*>(
&(scope->Var(name)->Get<phi::DenseTensor>()))));
&(inner_scope->Var(name)->Get<phi::DenseTensor>()))));
} else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>(
&(scope->Var(name)->Get<phi::SelectedRows>()))));
&(inner_scope->Var(name)->Get<phi::SelectedRows>()))));
} else if (out_type.isa<ir::VectorType>()) {
OutListType outputs;
auto& variable_array =
......
......@@ -17,6 +17,7 @@
#include <ostream>
#include <vector>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/macros.h"
#include "paddle/ir/core/op_info.h"
#include "paddle/ir/core/operation_utils.h"
......@@ -66,6 +67,13 @@ class IR_API alignas(8) Operation final {
const AttributeMap &attributes() const { return attributes_; }
template <typename T>
T attribute(const std::string &name) {
IR_ENFORCE(attributes().count(name) > 0 && attributes().at(name).isa<T>(),
"Attribute is not right.");
return attributes().at(name).dyn_cast<T>();
}
// Stores `value` under `key` in this operation's attribute map via
// operator[] assignment.
// NOTE(review): presumably this overwrites an existing entry with the same
// key; confirm against the AttributeMap type.
void set_attribute(const std::string &key, Attribute value) {
attributes_[key] = value;
}
......
# skip win32 since wget is not installed by default on windows machine.
if(NOT WIN32)
cc_test(
standalone_executor_new_ir_test
SRCS standalone_executor_new_ir_test.cc
DEPS phi_kernel_adaptor pd_dialect ir)
endif()
set(OPS
fill_constant_op
uniform_random_op
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/standalone_executor.h"
#include <gtest/gtest.h>
#include <chrono>
#include <iostream>
#include <string>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/platform/init_phi.h"
DECLARE_FILE_SYMBOLS(kernel_dialect);
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(full_int_array, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(uniform, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sqrt, CPU, ALL_LAYOUT);
// Returns true when `b` matches `a` within a relative tolerance of 1e-5
// (measured against `a`).
//
// Exactly equal inputs always match. When `a` is zero the relative error is
// undefined (division by zero gives NaN or infinity), so the check falls back
// to an absolute tolerance of 1e-5 on `b`.
bool simple_cmp(float a, float b) {
  if (a == b) {
    return true;
  }
  if (a == 0.0f) {
    return std::abs(b) < 1e-5;
  }
  return std::abs((a - b) / a) < 1e-5;
}
namespace paddle {
namespace framework {
// Builds a program computing full([2,2], 1.0) + full([2,2], 1.0), lowers it
// to the kernel dialect, executes it on CPU through InterpreterCore::BetaRun
// and checks that every element of the result ("inner_var_2") equals 2.
TEST(StandaloneExecutor, run) {
  ir::IrContext* ctx = ir::IrContext::Instance();
  ir::Program program((ctx));

  ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();

  ir::Builder builder = ir::Builder(ctx, program.block());

  paddle::dialect::FullOp op1 = builder.Build<paddle::dialect::FullOp>(
      std::vector<int64_t>{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());

  paddle::dialect::FullOp op2 = builder.Build<paddle::dialect::FullOp>(
      std::vector<int64_t>{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());

  builder.Build<paddle::dialect::AddOp>(op1->result(0), op2->result(0));

  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

  auto place = platform::CPUPlace();
  Scope scope;

  InterpreterCore test_core(place, std::move(kernel_program), &scope);

  test_core.BetaRun({});

  // The result lives in the local scope when the interpreter created one,
  // otherwise in the scope passed to the interpreter.
  const Scope* result_scope = test_core.local_scope() == nullptr
                                  ? &scope
                                  : test_core.local_scope();

  // Fail the test cleanly instead of dereferencing a missing variable.
  auto* out_var = result_scope->FindVar("inner_var_2");
  ASSERT_TRUE(out_var != nullptr);

  const auto& out_tensor = out_var->Get<phi::DenseTensor>();
  ASSERT_EQ(out_tensor.numel(), 4);

  const float* out_data = out_tensor.data<float>();
  for (int i = 0; i < 4; ++i) {
    EXPECT_TRUE(simple_cmp(out_data[i], 2.0)) << "mismatch at element " << i;
  }
}
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册