未验证 提交 dfcf746e 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #14904 from panyx0718/clean2

refactor RunImpl
......@@ -278,7 +278,8 @@ std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU");
void NgraphEngine::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
op->RuntimeInferShape(scope_, place_);
RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
op->RuntimeInferShape(scope_, place_, ctx);
for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) {
auto* var = scope_.FindVar(var_name);
......
......@@ -137,6 +137,23 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
}
}
RuntimeContext::RuntimeContext(const VariableNameMap& innames,
const VariableNameMap& outnames,
const Scope& scope) {
for (auto& var_name_item : innames) {
std::vector<Variable*>& input_vars = inputs[var_name_item.first];
for (auto& var_name : var_name_item.second) {
input_vars.push_back(scope.FindVar(var_name));
}
}
for (auto& var_name_item : outnames) {
std::vector<Variable*>& output_vars = outputs[var_name_item.first];
for (auto& var_name : var_name_item.second) {
output_vars.push_back(scope.FindVar(var_name));
}
}
}
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
VLOG(4) << place << " " << DebugStringEx(&scope);
if (platform::is_gpu_place(place)) {
......@@ -412,11 +429,48 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
return var != nullptr;
}
const Variable* ExecutionContext::InputVar(const std::string& name) const {
auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) return nullptr;
PADDLE_ENFORCE_LE(it->second.size(), 1UL,
"Operator %s's input %s should contain only one variable.",
op_.Type(), name);
return it->second.empty() ? nullptr : it->second[0];
}
const Variable* ExecutionContext::LegacyInputVar(
const std::string& name) const {
auto ipt = op_.Input(name);
return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
}
Variable* ExecutionContext::OutputVar(const std::string& name) const {
auto it = ctx_.outputs.find(name);
if (it == ctx_.outputs.end()) return nullptr;
PADDLE_ENFORCE_LE(it->second.size(), 1UL,
"Operator %s's output %s should contain only one variable.",
op_.Type(), name);
return it->second.empty() ? nullptr : it->second[0];
}
Variable* ExecutionContext::LegacyOutputVar(const std::string& name) const {
auto opt = op_.Output(name);
return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
}
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
return Input<LoDTensor>(name);
}
template <>
const Tensor* ExecutionContext::LegacyInput<Tensor>(
const std::string& name) const {
return LegacyInput<LoDTensor>(name);
}
template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const {
......@@ -441,6 +495,11 @@ Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
return Output<LoDTensor>(name);
}
template <>
Tensor* ExecutionContext::LegacyOutput<Tensor>(const std::string& name) const {
return LegacyOutput<LoDTensor>(name);
}
template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const {
......@@ -477,23 +536,22 @@ bool OpSupportGPU(const std::string& op_type) {
class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope,
const RuntimeContext& ctx)
: op_(op), scope_(scope), ctx_(ctx) {}
bool HasInput(const std::string& name) const override {
// has only one input
const auto& ins = op_.Inputs();
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end()) {
return false;
}
const auto& in = it->second;
if (in.size() == 0 || in[0] == kEmptyVarName) {
return false;
}
if (in.size() == 0) return false;
PADDLE_ENFORCE_EQ(in.size(), 1UL,
"Input %s should not have more than one inputs", name);
return scope_.FindVar(in[0]) != nullptr;
return in[0] != nullptr;
}
bool HasOutput(const std::string& name) const override {
......@@ -678,6 +736,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
private:
const OperatorBase& op_;
const Scope& scope_;
const RuntimeContext& ctx_;
};
static void CheckTensorNANOrInf(const std::string& name,
......@@ -696,15 +755,15 @@ static void CheckTensorNANOrInf(const std::string& name,
}
void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
const platform::Place& place,
const RuntimeContext& ctx) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope, ctx);
this->InferShape(&infer_shape_ctx);
}
void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
RuntimeContext ctx(Inputs(), Outputs(), scope);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
......@@ -718,15 +777,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelMap& kernels = kernels_iter->second;
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the
// transform functions are ready.
// for (auto& candidate : kKernelPriority) {
// Do selection
// }
auto expected_kernel_key =
this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
......@@ -748,7 +800,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope =
TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);
PrepareData(scope, expected_kernel_key, &transfered_inplace_vars, &ctx);
// exec scope is the scope that kernel actually executed on.
const Scope& exec_scope =
......@@ -758,7 +810,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx = pool.Get(expected_kernel_key.place_);
}
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
this->InferShape(&infer_shape_ctx);
// TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
// not Scope. Imperative mode only pass inputs and get outputs.
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, ctx));
if (!transfered_inplace_vars.empty()) {
// there is inplace variable has been transfered.
......@@ -782,6 +838,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
}
}
void OperatorWithKernel::TransferInplaceVarsBack(
const Scope& scope, const std::vector<std::string>& inplace_vars,
const Scope& transfer_scope) const {
......@@ -797,13 +854,19 @@ void OperatorWithKernel::TransferInplaceVarsBack(
}
}
Scope* OperatorWithKernel::TryTransferData(
Scope* OperatorWithKernel::PrepareData(
const Scope& scope, const OpKernelType& expected_kernel_key,
std::vector<std::string>* transfered_inplace_vars) const {
std::vector<std::string>* transfered_inplace_vars,
RuntimeContext* ctx) const {
Scope* new_scope = nullptr;
for (auto& var_name_item : Inputs()) {
for (auto& var_name : var_name_item.second) {
std::vector<Variable*>& input_vars = ctx->inputs[var_name_item.first];
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto& var_name = var_name_item.second[i];
auto* var = scope.FindVar(var_name);
input_vars[i] = var;
// Only tensor can be tranfer to another device.
if (var == nullptr || !VarIsTensor(*var)) {
continue;
......@@ -851,6 +914,7 @@ Scope* OperatorWithKernel::TryTransferData(
}
auto* trans_var = new_scope->Var(var_name);
input_vars[i] = trans_var;
Tensor out;
TransformData(expected_kernel_key, kernel_type_for_var, *tensor_in, &out);
......
......@@ -70,6 +70,15 @@ Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
class OperatorBase;
class ExecutionContext;
class RuntimeContext {
public:
RuntimeContext(const VariableNameMap& innames,
const VariableNameMap& outnames, const Scope& scope);
VariableValueMap inputs;
VariableValueMap outputs;
};
/**
* OperatorBase has the basic elements that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
......@@ -129,7 +138,8 @@ class OperatorBase {
void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }
virtual void RuntimeInferShape(const Scope& scope,
const platform::Place& place) const {}
const platform::Place& place,
const RuntimeContext& ctx) const {}
protected:
std::string type_;
......@@ -156,8 +166,9 @@ class OperatorBase {
class ExecutionContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext& device_context)
: op_(op), scope_(scope), device_context_(device_context) {}
const platform::DeviceContext& device_context,
const RuntimeContext& ctx)
: op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
const OperatorBase& op() const { return op_; }
......@@ -180,15 +191,9 @@ class ExecutionContext {
return op_.Outputs(name).size();
}
const Variable* InputVar(const std::string& name) const {
auto ipt = op_.Input(name);
return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
}
const Variable* InputVar(const std::string& name) const;
Variable* OutputVar(const std::string& name) const {
auto opt = op_.Output(name);
return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
}
Variable* OutputVar(const std::string& name) const;
const std::vector<const Variable*> MultiInputVar(
const std::string& name) const {
......@@ -227,6 +232,22 @@ class ExecutionContext {
return var == nullptr ? nullptr : var->GetMutable<T>();
}
template <typename T>
const T* LegacyInput(const std::string& name) const {
auto* var = LegacyInputVar(name);
return var == nullptr ? nullptr : &var->Get<T>();
}
template <typename T>
T* LegacyOutput(const std::string& name) const {
auto var = LegacyOutputVar(name);
return var == nullptr ? nullptr : var->GetMutable<T>();
}
const Variable* LegacyInputVar(const std::string& name) const;
Variable* LegacyOutputVar(const std::string& name) const;
template <typename T>
const std::vector<const T*> MultiInput(const std::string& name) const {
auto names = op_.Inputs(name);
......@@ -286,11 +307,16 @@ class ExecutionContext {
const OperatorBase& op_;
const Scope& scope_;
const platform::DeviceContext& device_context_;
const RuntimeContext& ctx_;
};
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;
template <>
const Tensor* ExecutionContext::LegacyInput<Tensor>(
const std::string& name) const;
template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const;
......@@ -298,6 +324,9 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
template <>
Tensor* ExecutionContext::LegacyOutput<Tensor>(const std::string& name) const;
template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const;
......@@ -350,8 +379,8 @@ class OperatorWithKernel : public OperatorBase {
OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
}
void RuntimeInferShape(const Scope& scope,
const platform::Place& place) const override;
void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const RuntimeContext& ctx) const override;
protected:
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
......@@ -371,9 +400,10 @@ class OperatorWithKernel : public OperatorBase {
*
* * transfered_inplace_vars is a output vector.
*/
Scope* TryTransferData(
const Scope& scope, const OpKernelType& expected_kernel_key,
std::vector<std::string>* transfered_inplace_vars) const;
Scope* PrepareData(const Scope& scope,
const OpKernelType& expected_kernel_key,
std::vector<std::string>* transfered_inplace_vars,
RuntimeContext* ctx) const;
void TransferInplaceVarsBack(const Scope& scope,
const std::vector<std::string>& inplace_vars,
......
......@@ -28,8 +28,11 @@ class OperatorBase;
class OpDesc;
class InferShapeContext;
class BlockDesc;
class Variable;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// TODO(panyx0718): Replace vector with something like gtl::Vector.
using VariableValueMap = std::map<std::string, std::vector<Variable*>>;
// The order should be as same as framework.proto
using Attribute =
......
......@@ -122,7 +122,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(dev_place);
framework::ExecutionContext ctx(*this, scope, dev_ctx);
framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx);
const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册