提交 840e6729 编写于 作者: X Xin Pan

inject context

test=develop
上级 bbff0df3
...@@ -278,19 +278,7 @@ std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ = ...@@ -278,19 +278,7 @@ std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU"); ngraph::runtime::Backend::create("CPU");
void NgraphEngine::GetNgInputShape(std::shared_ptr<OperatorBase> op) { void NgraphEngine::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
RuntimeContext ctx; RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
for (auto& var_name_item : op->Inputs()) {
std::vector<Variable*> input_vars = ctx.inputs[var_name_item.first];
for (auto& var_name : var_name_item.second) {
input_vars.push_back(scope_.FindVar(var_name));
}
}
for (auto& var_name_item : op->Outputs()) {
std::vector<Variable*> output_vars = ctx.outputs[var_name_item.first];
for (auto& var_name : var_name_item.second) {
output_vars.push_back(scope_.FindVar(var_name));
}
}
op->RuntimeInferShape(scope_, place_, ctx); op->RuntimeInferShape(scope_, place_, ctx);
for (auto& var_name_item : op->Inputs()) { for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
......
...@@ -137,6 +137,23 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { ...@@ -137,6 +137,23 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
} }
} }
RuntimeContext::RuntimeContext(const VariableNameMap& innames,
const VariableNameMap& outnames,
const Scope& scope) {
for (auto& var_name_item : innames) {
std::vector<Variable*>& input_vars = inputs[var_name_item.first];
for (auto& var_name : var_name_item.second) {
input_vars.push_back(scope.FindVar(var_name));
}
}
for (auto& var_name_item : outnames) {
std::vector<Variable*>& output_vars = outputs[var_name_item.first];
for (auto& var_name : var_name_item.second) {
output_vars.push_back(scope.FindVar(var_name));
}
}
}
void OperatorBase::Run(const Scope& scope, const platform::Place& place) { void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
VLOG(4) << place << " " << DebugStringEx(&scope); VLOG(4) << place << " " << DebugStringEx(&scope);
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
...@@ -704,6 +721,7 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope, ...@@ -704,6 +721,7 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RuntimeContext ctx(Inputs(), Outputs(), scope);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
...@@ -717,15 +735,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -717,15 +735,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelMap& kernels = kernels_iter->second; OpKernelMap& kernels = kernels_iter->second;
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the auto expected_kernel_key = this->GetExpectedKernelType(
// transform functions are ready. ExecutionContext(*this, scope, *dev_ctx, ctx));
// for (auto& candidate : kKernelPriority) {
// Do selection
// }
auto expected_kernel_key =
this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key; VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key); auto kernel_iter = kernels.find(expected_kernel_key);
...@@ -744,7 +755,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -744,7 +755,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
KernelTypeToString(expected_kernel_key)); KernelTypeToString(expected_kernel_key));
} }
RuntimeContext ctx;
// do data transformScope &transfer_scope; // do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars; std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope = auto* transfer_scope =
...@@ -760,7 +770,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -760,7 +770,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx); RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx)); kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, ctx));
if (!transfered_inplace_vars.empty()) { if (!transfered_inplace_vars.empty()) {
// there is inplace variable has been transfered. // there is inplace variable has been transfered.
...@@ -784,6 +794,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -784,6 +794,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
} }
} }
} }
void OperatorWithKernel::TransferInplaceVarsBack( void OperatorWithKernel::TransferInplaceVarsBack(
const Scope& scope, const std::vector<std::string>& inplace_vars, const Scope& scope, const std::vector<std::string>& inplace_vars,
const Scope& transfer_scope) const { const Scope& transfer_scope) const {
...@@ -806,7 +817,6 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -806,7 +817,6 @@ Scope* OperatorWithKernel::PrepareData(
Scope* new_scope = nullptr; Scope* new_scope = nullptr;
for (auto& var_name_item : Inputs()) { for (auto& var_name_item : Inputs()) {
std::vector<Variable*>& input_vars = ctx->inputs[var_name_item.first]; std::vector<Variable*>& input_vars = ctx->inputs[var_name_item.first];
input_vars.resize(var_name_item.second.size());
for (size_t i = 0; i < var_name_item.second.size(); ++i) { for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto& var_name = var_name_item.second[i]; auto& var_name = var_name_item.second[i];
...@@ -869,8 +879,6 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -869,8 +879,6 @@ Scope* OperatorWithKernel::PrepareData(
} }
for (auto& var_name_item : Outputs()) { for (auto& var_name_item : Outputs()) {
std::vector<Variable*>& output_vars = ctx->outputs[var_name_item.first]; std::vector<Variable*>& output_vars = ctx->outputs[var_name_item.first];
output_vars.resize(var_name_item.second.size());
for (size_t i = 0; i < var_name_item.second.size(); ++i) { for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto& var_name = var_name_item.second[i]; auto& var_name = var_name_item.second[i];
output_vars[i] = scope.FindVar(var_name); output_vars[i] = scope.FindVar(var_name);
......
...@@ -72,7 +72,8 @@ class ExecutionContext; ...@@ -72,7 +72,8 @@ class ExecutionContext;
class RuntimeContext { class RuntimeContext {
public: public:
RuntimeContext() {} RuntimeContext(const VariableNameMap& innames,
const VariableNameMap& outnames, const Scope& scope);
VariableValueMap inputs; VariableValueMap inputs;
VariableValueMap outputs; VariableValueMap outputs;
...@@ -165,8 +166,9 @@ class OperatorBase { ...@@ -165,8 +166,9 @@ class OperatorBase {
class ExecutionContext { class ExecutionContext {
public: public:
ExecutionContext(const OperatorBase& op, const Scope& scope, ExecutionContext(const OperatorBase& op, const Scope& scope,
const platform::DeviceContext& device_context) const platform::DeviceContext& device_context,
: op_(op), scope_(scope), device_context_(device_context) {} const RuntimeContext& ctx)
: op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
const OperatorBase& op() const { return op_; } const OperatorBase& op() const { return op_; }
...@@ -295,6 +297,7 @@ class ExecutionContext { ...@@ -295,6 +297,7 @@ class ExecutionContext {
const OperatorBase& op_; const OperatorBase& op_;
const Scope& scope_; const Scope& scope_;
const platform::DeviceContext& device_context_; const platform::DeviceContext& device_context_;
const RuntimeContext& ctx_;
}; };
template <> template <>
......
...@@ -122,7 +122,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase { ...@@ -122,7 +122,8 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(dev_place); auto& dev_ctx = *pool.Get(dev_place);
framework::ExecutionContext ctx(*this, scope, dev_ctx); framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope);
framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx);
const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids"); const LoDTensorArray* ids = ctx.Input<LoDTensorArray>("Ids");
const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores"); const LoDTensorArray* scores = ctx.Input<LoDTensorArray>("Scores");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册