未验证 提交 251f68e7 编写于 作者: X xiongkun 提交者: GitHub

Add Support for OperatorBase in new executor (#36945)

* add scope as membership

* functions complete

* fix bugs: garbage collectior

* deal unknow variable holder

* add

* 1. add unittest for operator_base

* code format
上级 68c3e2cb
...@@ -316,13 +316,18 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -316,13 +316,18 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
auto place = instr_node.DeviceContext().GetPlace(); auto place = instr_node.DeviceContext().GetPlace();
VLOG(4) << place << " " << op->DebugStringEx(global_scope_); VLOG(4) << place << " " << op->DebugStringEx(global_scope_);
auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
{ {
platform::RecordEvent infershape_event("InferShape"); platform::RecordEvent infershape_event("InferShape");
static_cast<const framework::OperatorWithKernel*>(instr_node.OpBase()) // If it is OperatorBase, InferShape do nothing.
->InferShape(instr_node.InnerInferShapeContext().get()); if (op_with_kernel != nullptr)
op_with_kernel->InferShape(instr_node.InnerInferShapeContext().get());
} }
if (FLAGS_new_executor_use_inplace) { if (op_with_kernel != nullptr &&
FLAGS_new_executor_use_inplace) { // TODO(xiongkun03) Does operator
// base support
// inplace ?
for (auto& pair : instr_node.InplaceInfo()) { for (auto& pair : instr_node.InplaceInfo()) {
const auto& in = paddle::framework::details::GetTensorFromVar(pair.first); const auto& in = paddle::framework::details::GetTensorFromVar(pair.first);
auto* out = auto* out =
...@@ -334,6 +339,9 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -334,6 +339,9 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
} }
{ {
platform::RecordEvent compute_event("Compute"); platform::RecordEvent compute_event("Compute");
if (op_with_kernel == nullptr)
instr_node.OpBase()->Run(*global_scope_->GetScope(), place_);
else
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
} }
...@@ -357,7 +365,9 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -357,7 +365,9 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
// for debug nan/inf // for debug nan/inf
if (FLAGS_check_nan_inf) { if (FLAGS_check_nan_inf) {
VLOG(4) << "Check nan/inf"; VLOG(4) << "Check nan/inf";
framework::details::CheckOpHasNanOrInf(*op, *global_scope_, place); framework::details::CheckOpHasNanOrInf(
*op, *global_scope_,
place); // TODO(xiongkun03) change it to inner scope.
} }
} }
......
...@@ -58,6 +58,11 @@ void InterpreterCoreGarbageCollector::Add(paddle::framework::Variable* var, ...@@ -58,6 +58,11 @@ void InterpreterCoreGarbageCollector::Add(paddle::framework::Variable* var,
const platform::DeviceContext* ctx) { const platform::DeviceContext* ctx) {
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
Add(var->GetMutable<LoDTensor>()->MoveMemoryHolder(), event, ctx); Add(var->GetMutable<LoDTensor>()->MoveMemoryHolder(), event, ctx);
} else if (var->IsType<
operators::reader::
OrderedMultiDeviceLoDTensorBlockingQueueHolder>()) {
// var->Clear(); // TODO(xiongkun03) can we clear directly? Why we must use
// Add interface?
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
Add(var->GetMutable<SelectedRows>()->mutable_value()->MoveMemoryHolder(), Add(var->GetMutable<SelectedRows>()->mutable_value()->MoveMemoryHolder(),
event, ctx); event, ctx);
...@@ -66,6 +71,10 @@ void InterpreterCoreGarbageCollector::Add(paddle::framework::Variable* var, ...@@ -66,6 +71,10 @@ void InterpreterCoreGarbageCollector::Add(paddle::framework::Variable* var,
for (auto& t : *tensor_arr) { for (auto& t : *tensor_arr) {
Add(t.MoveMemoryHolder(), event, ctx); Add(t.MoveMemoryHolder(), event, ctx);
} }
} else if (var->IsType<std::vector<Scope*>>()) {
// NOTE(@xiongkun03) conditional_op / while_op will create a STEP_SCOPE
// refer to executor.cc to see what old garbage collector does.
// do nothing, because the sub scope will be deleted by sub-executor.
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"The variable(%s) is not supported in eager deletion.", "The variable(%s) is not supported in eager deletion.",
......
...@@ -229,6 +229,28 @@ void apply_device_guard(const OperatorBase* op_base, ...@@ -229,6 +229,28 @@ void apply_device_guard(const OperatorBase* op_base,
} }
} }
void deal_operator_base(const platform::Place& place,
const VariableScope* var_scope, OperatorBase* op_base,
OpFuncNode* op_func_node) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// input, output is prepared. set the other attributes.
op_func_node->operator_base_ = op_base;
op_func_node->type_ = OpFuncType::kQueueSync; // alway Sync
op_func_node->kernel_func_ = nullptr;
op_base->Run(*var_scope->GetScope(), place); // Run without data transformer.
std::unordered_set<int> no_data_transform_index;
for (auto& it : op_func_node->input_index) {
for (auto& id : it.second) {
no_data_transform_index.emplace(id);
}
}
op_func_node->no_data_transform_index =
no_data_transform_index; // all index is no-need-transform
op_func_node->dev_ctx_ = dev_ctx;
}
// the return value is whether data transformer is needed for this var // the return value is whether data transformer is needed for this var
bool need_place_transform_for_var(const OpKernelType& kernel_type_for_var, bool need_place_transform_for_var(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key) { const OpKernelType& expected_kernel_key) {
...@@ -429,11 +451,18 @@ void build_op_func_list(const platform::Place& place, ...@@ -429,11 +451,18 @@ void build_op_func_list(const platform::Place& place,
OpFuncNode op_func_node; OpFuncNode op_func_node;
op_func_node.input_index = ins_name2id; op_func_node.input_index = ins_name2id;
op_func_node.output_index = outs_name2id; op_func_node.output_index = outs_name2id;
if (dynamic_cast<const framework::OperatorWithKernel*>(op_base) ==
nullptr) {
// op is not a operatorwithkernel, so direcly run OperatorBase::Run()
deal_operator_base(place, var_scope, op_base, &op_func_node);
} else {
// construct RuntimeContext and analysis KernelType // construct RuntimeContext and analysis KernelType
RuntimeContext runtime_context({}, {}); RuntimeContext runtime_context({}, {});
runtime_context.inputs.swap(ins_map); runtime_context.inputs.swap(ins_map);
runtime_context.outputs.swap(outs_map); runtime_context.outputs.swap(outs_map);
InterpretercoreInferShapeContext infer_shape_ctx(*op_base, runtime_context); InterpretercoreInferShapeContext infer_shape_ctx(*op_base,
runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT inheritted // TODO(Aurelius84): In case of control flow ops, they are NOT inheritted
// from OperatorWithKernel. // from OperatorWithKernel.
static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape( static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
...@@ -447,7 +476,8 @@ void build_op_func_list(const platform::Place& place, ...@@ -447,7 +476,8 @@ void build_op_func_list(const platform::Place& place,
OpKernelMap& kernels = kernels_iter->second; OpKernelMap& kernels = kernels_iter->second;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
Scope scope; Scope scope;
auto expected_kernel_key = auto expected_kernel_key =
...@@ -484,7 +514,6 @@ void build_op_func_list(const platform::Place& place, ...@@ -484,7 +514,6 @@ void build_op_func_list(const platform::Place& place,
PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s", PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
expected_kernel_key.place_)); expected_kernel_key.place_));
} }
if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) { if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
} }
...@@ -494,15 +523,17 @@ void build_op_func_list(const platform::Place& place, ...@@ -494,15 +523,17 @@ void build_op_func_list(const platform::Place& place,
ExecutionContext(*op_base, scope, *dev_ctx, runtime_context); ExecutionContext(*op_base, scope, *dev_ctx, runtime_context);
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
PADDLE_ENFORCE_NE(kernel_iter, kernels.end(), PADDLE_ENFORCE_NE(
kernel_iter, kernels.end(),
platform::errors::NotFound( platform::errors::NotFound(
"Operator (%s) does not have kernel for %s.", "Operator (%s) does not have kernel for %s.", op->Type(),
op->Type(), KernelTypeToString(expected_kernel_key))); KernelTypeToString(expected_kernel_key)));
op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
op_func_node.kernel_func_(exec_ctx); op_func_node.kernel_func_(exec_ctx);
vec_func_list->push_back(op_func_node); }
vec_func_list->push_back(op_func_node);
// gc--------------------------------------------------------------------------- // gc---------------------------------------------------------------------------
auto iter = unused_var_map.find(op_base); auto iter = unused_var_map.find(op_base);
if (iter == unused_var_map.end()) { if (iter == unused_var_map.end()) {
......
...@@ -472,6 +472,12 @@ struct VariableMetaInfo { ...@@ -472,6 +472,12 @@ struct VariableMetaInfo {
}; };
// TODO(zhiqiu): Maybe we need to add rwlock for VariableScope? // TODO(zhiqiu): Maybe we need to add rwlock for VariableScope?
// NOTE(xiongkun03): Use scope as a member of VariableScope, we don't need
// ScopeBase.
// Scope manager the variables and VariableScope is just a
// quick
// access machanism.
class VariableScope : public ScopeBase { class VariableScope : public ScopeBase {
public: public:
VariableScope() { VariableScope() {
...@@ -482,7 +488,10 @@ class VariableScope : public ScopeBase { ...@@ -482,7 +488,10 @@ class VariableScope : public ScopeBase {
info.var_ref_count_ = 0; info.var_ref_count_ = 0;
info.vardesc_ = nullptr; info.vardesc_ = nullptr;
vec_meta_info_.push_back(info); vec_meta_info_.push_back(info);
scope_ptr_.reset(new Scope());
} }
const Scope* GetScope() const { return scope_ptr_.get(); }
Variable* FindVar(const std::string& name) const { Variable* FindVar(const std::string& name) const {
auto it = name2id_.find(name); auto it = name2id_.find(name);
if (it != name2id_.end()) { if (it != name2id_.end()) {
...@@ -540,11 +549,14 @@ class VariableScope : public ScopeBase { ...@@ -540,11 +549,14 @@ class VariableScope : public ScopeBase {
void AddVar(const std::string& name, VarDesc* var_desc) { // NOLINT void AddVar(const std::string& name, VarDesc* var_desc) { // NOLINT
name2id_[name] = VarSize(); name2id_[name] = VarSize();
auto v = new Variable(); auto v = scope_ptr_->Var(name);
if (nullptr == var_desc) { if (nullptr == var_desc) {
v->GetMutable<LoDTensor>(); v->GetMutable<LoDTensor>();
} else { } else {
InitializeVariable(v, var_desc->GetType()); InitializeVariable(
v,
var_desc
->GetType()); // Scope don't initialize variable recently created
} }
var_list_.push_back(v); var_list_.push_back(v);
...@@ -555,8 +567,12 @@ class VariableScope : public ScopeBase { ...@@ -555,8 +567,12 @@ class VariableScope : public ScopeBase {
} }
void AddVar(const std::string& name, Variable& var) { // NOLINT void AddVar(const std::string& name, Variable& var) { // NOLINT
// must copy.
VLOG(4) << "Add variable: " << name << " through AddVar()";
auto v = scope_ptr_->Var(name);
*v = var;
name2id_[name] = VarSize(); name2id_[name] = VarSize();
var_list_.push_back(&var); var_list_.push_back(v);
VariableMetaInfo info; VariableMetaInfo info;
info.var_ref_count_ = 0; info.var_ref_count_ = 0;
...@@ -595,6 +611,7 @@ class VariableScope : public ScopeBase { ...@@ -595,6 +611,7 @@ class VariableScope : public ScopeBase {
std::vector<Variable*> var_list_; std::vector<Variable*> var_list_;
std::map<std::string, int> name2id_; std::map<std::string, int> name2id_;
std::vector<VariableMetaInfo> vec_meta_info_; std::vector<VariableMetaInfo> vec_meta_info_;
std::unique_ptr<Scope> scope_ptr_;
}; };
class NextInstruction { class NextInstruction {
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest
import paddle
from paddle.fluid import core
from paddle.fluid.core import StandaloneExecutor
import paddle.fluid as fluid
from paddle.fluid.framework import Program, program_guard
import paddle.fluid.layers as layers
import numpy as np
paddle.enable_static()
# test the compatibility of new executor: run old
# and new executor twice and check the result.
# please override the _get_feeds() and build_prgram()
class TestCompatibility(unittest.TestCase):
def setUp(self):
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
self.iter_run = 4
def _get_feed(self):
""" return the feeds
"""
return None
def build_program(self):
def true_func():
return layers.fill_constant(
shape=[1, 2], dtype='int32', value=1), layers.fill_constant(
shape=[2, 3], dtype='bool', value=True)
def false_func():
return layers.fill_constant(
shape=[3, 4], dtype='float32', value=3), layers.fill_constant(
shape=[4, 5], dtype='int64', value=2)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
pred = layers.less_than(x, y)
out = layers.cond(pred, true_func, false_func)
# out is a tuple containing 2 tensors
return main_program, startup_program, out
def _run(self, feed):
paddle.seed(2020)
main_program, startup_program, fetch_vars = self.build_program()
exe = paddle.static.Executor(self.place)
exe.run(startup_program)
ret = []
for i in range(self.iter_run):
ret.append(exe.run(main_program, feed=feed, fetch_list=fetch_vars))
return ret
def run_raw_executor(self, feed):
out = self._run(feed)
print("GT:", out)
return out
def run_new_executor(self, feed):
os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
out = self._run(feed)
del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
print("New:", out)
return out
def test_with_feed(self):
feed = self._get_feed()
res = self.run_new_executor(feed)
gt = self.run_raw_executor(feed)
for x, y in zip(gt, res):
if isinstance(x, list):
for tx, ty in zip(x, y):
self.assertTrue(np.array_equal(tx, ty))
elif isinstance(x, np.ndarray):
self.assertTrue(np.array_equal(tx, ty))
else:
raise Exception("Not Implement!")
class TestWhile(TestCompatibility):
def _get_feed(self):
""" return the feeds
"""
return None
def build_program(self):
def cond(i, ten):
return i < ten
def body(i, ten):
i = i + 1
return [i, ten]
main_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with paddle.static.program_guard(main_program, startup_program):
i = paddle.full(
shape=[1], fill_value=0, dtype='int64') # loop counter
ten = paddle.full(
shape=[1], fill_value=10, dtype='int64') # loop length
i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])
exe = paddle.static.Executor(paddle.CPUPlace())
return main_program, startup_program, i
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册