提交 bbff0df3 编写于 作者: X Xin Pan

try cache variables

test=develop
上级 52bc4ee7
...@@ -278,7 +278,20 @@ std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ = ...@@ -278,7 +278,20 @@ std::shared_ptr<ngraph::runtime::Backend> NgraphEngine::backend_ =
ngraph::runtime::Backend::create("CPU"); ngraph::runtime::Backend::create("CPU");
void NgraphEngine::GetNgInputShape(std::shared_ptr<OperatorBase> op) { void NgraphEngine::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
op->RuntimeInferShape(scope_, place_); RuntimeContext ctx;
for (auto& var_name_item : op->Inputs()) {
std::vector<Variable*> input_vars = ctx.inputs[var_name_item.first];
for (auto& var_name : var_name_item.second) {
input_vars.push_back(scope_.FindVar(var_name));
}
}
for (auto& var_name_item : op->Outputs()) {
std::vector<Variable*> output_vars = ctx.outputs[var_name_item.first];
for (auto& var_name : var_name_item.second) {
output_vars.push_back(scope_.FindVar(var_name));
}
}
op->RuntimeInferShape(scope_, place_, ctx);
for (auto& var_name_item : op->Inputs()) { for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
auto* var = scope_.FindVar(var_name); auto* var = scope_.FindVar(var_name);
......
...@@ -477,23 +477,22 @@ bool OpSupportGPU(const std::string& op_type) { ...@@ -477,23 +477,22 @@ bool OpSupportGPU(const std::string& op_type) {
class RuntimeInferShapeContext : public InferShapeContext { class RuntimeInferShapeContext : public InferShapeContext {
public: public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope,
: op_(op), scope_(scope) {} const RuntimeContext& ctx)
: op_(op), scope_(scope), ctx_(ctx) {}
bool HasInput(const std::string& name) const override { bool HasInput(const std::string& name) const override {
// has only one input // has only one input
const auto& ins = op_.Inputs(); const auto& ins = ctx_.inputs;
auto it = ins.find(name); auto it = ins.find(name);
if (it == ins.end()) { if (it == ins.end()) {
return false; return false;
} }
const auto& in = it->second; const auto& in = it->second;
if (in.size() == 0 || in[0] == kEmptyVarName) { if (in.size() == 0) return false;
return false;
}
PADDLE_ENFORCE_EQ(in.size(), 1UL, PADDLE_ENFORCE_EQ(in.size(), 1UL,
"Input %s should not have more than one inputs", name); "Input %s should not have more than one inputs", name);
return scope_.FindVar(in[0]) != nullptr; return in[0] != nullptr;
} }
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
...@@ -678,6 +677,7 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -678,6 +677,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
private: private:
const OperatorBase& op_; const OperatorBase& op_;
const Scope& scope_; const Scope& scope_;
const RuntimeContext& ctx_;
}; };
static void CheckTensorNANOrInf(const std::string& name, static void CheckTensorNANOrInf(const std::string& name,
...@@ -696,8 +696,9 @@ static void CheckTensorNANOrInf(const std::string& name, ...@@ -696,8 +696,9 @@ static void CheckTensorNANOrInf(const std::string& name,
} }
void OperatorWithKernel::RuntimeInferShape(const Scope& scope, void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place) const { const platform::Place& place,
RuntimeInferShapeContext infer_shape_ctx(*this, scope); const RuntimeContext& ctx) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope, ctx);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
} }
...@@ -743,10 +744,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -743,10 +744,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
KernelTypeToString(expected_kernel_key)); KernelTypeToString(expected_kernel_key));
} }
RuntimeContext ctx;
// do data transformScope &transfer_scope; // do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars; std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope = auto* transfer_scope =
TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars); PrepareData(scope, expected_kernel_key, &transfered_inplace_vars, &ctx);
// exec scope is the scope that kernel actually executed on. // exec scope is the scope that kernel actually executed on.
const Scope& exec_scope = const Scope& exec_scope =
...@@ -756,7 +758,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -756,7 +758,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
dev_ctx = pool.Get(expected_kernel_key.place_); dev_ctx = pool.Get(expected_kernel_key.place_);
} }
RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope); RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx)); kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
...@@ -797,13 +799,20 @@ void OperatorWithKernel::TransferInplaceVarsBack( ...@@ -797,13 +799,20 @@ void OperatorWithKernel::TransferInplaceVarsBack(
} }
} }
Scope* OperatorWithKernel::TryTransferData( Scope* OperatorWithKernel::PrepareData(
const Scope& scope, const OpKernelType& expected_kernel_key, const Scope& scope, const OpKernelType& expected_kernel_key,
std::vector<std::string>* transfered_inplace_vars) const { std::vector<std::string>* transfered_inplace_vars,
RuntimeContext* ctx) const {
Scope* new_scope = nullptr; Scope* new_scope = nullptr;
for (auto& var_name_item : Inputs()) { for (auto& var_name_item : Inputs()) {
for (auto& var_name : var_name_item.second) { std::vector<Variable*>& input_vars = ctx->inputs[var_name_item.first];
input_vars.resize(var_name_item.second.size());
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto& var_name = var_name_item.second[i];
auto* var = scope.FindVar(var_name); auto* var = scope.FindVar(var_name);
input_vars[i] = var;
// Only tensor can be tranfer to another device. // Only tensor can be tranfer to another device.
if (var == nullptr || !VarIsTensor(*var)) { if (var == nullptr || !VarIsTensor(*var)) {
continue; continue;
...@@ -851,12 +860,22 @@ Scope* OperatorWithKernel::TryTransferData( ...@@ -851,12 +860,22 @@ Scope* OperatorWithKernel::TryTransferData(
} }
auto* trans_var = new_scope->Var(var_name); auto* trans_var = new_scope->Var(var_name);
input_vars[i] = var;
Tensor out; Tensor out;
TransformData(expected_kernel_key, kernel_type_for_var, *tensor_in, &out); TransformData(expected_kernel_key, kernel_type_for_var, *tensor_in, &out);
SetTensorToVariable(*var, out, trans_var); SetTensorToVariable(*var, out, trans_var);
} }
} }
for (auto& var_name_item : Outputs()) {
std::vector<Variable*>& output_vars = ctx->outputs[var_name_item.first];
output_vars.resize(var_name_item.second.size());
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto& var_name = var_name_item.second[i];
output_vars[i] = scope.FindVar(var_name);
}
}
return new_scope; return new_scope;
} }
......
...@@ -70,6 +70,14 @@ Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var); ...@@ -70,6 +70,14 @@ Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
class OperatorBase; class OperatorBase;
class ExecutionContext; class ExecutionContext;
class RuntimeContext {
public:
RuntimeContext() {}
VariableValueMap inputs;
VariableValueMap outputs;
};
/** /**
* OperatorBase has the basic elements that Net will call to do computation. * OperatorBase has the basic elements that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User * Only CreateOperator from OpRegistry will new Operator directly. User
...@@ -129,7 +137,8 @@ class OperatorBase { ...@@ -129,7 +137,8 @@ class OperatorBase {
void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; } void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }
virtual void RuntimeInferShape(const Scope& scope, virtual void RuntimeInferShape(const Scope& scope,
const platform::Place& place) const {} const platform::Place& place,
const RuntimeContext& ctx) const {}
protected: protected:
std::string type_; std::string type_;
...@@ -350,8 +359,8 @@ class OperatorWithKernel : public OperatorBase { ...@@ -350,8 +359,8 @@ class OperatorWithKernel : public OperatorBase {
OpInfoMap::Instance().Get(Type()).infer_shape_(ctx); OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
} }
void RuntimeInferShape(const Scope& scope, void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const platform::Place& place) const override; const RuntimeContext& ctx) const override;
protected: protected:
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
...@@ -371,9 +380,10 @@ class OperatorWithKernel : public OperatorBase { ...@@ -371,9 +380,10 @@ class OperatorWithKernel : public OperatorBase {
* *
* * transfered_inplace_vars is a output vector. * * transfered_inplace_vars is a output vector.
*/ */
Scope* TryTransferData( Scope* PrepareData(const Scope& scope,
const Scope& scope, const OpKernelType& expected_kernel_key, const OpKernelType& expected_kernel_key,
std::vector<std::string>* transfered_inplace_vars) const; std::vector<std::string>* transfered_inplace_vars,
RuntimeContext* ctx) const;
void TransferInplaceVarsBack(const Scope& scope, void TransferInplaceVarsBack(const Scope& scope,
const std::vector<std::string>& inplace_vars, const std::vector<std::string>& inplace_vars,
......
...@@ -28,8 +28,11 @@ class OperatorBase; ...@@ -28,8 +28,11 @@ class OperatorBase;
class OpDesc; class OpDesc;
class InferShapeContext; class InferShapeContext;
class BlockDesc; class BlockDesc;
class Variable;
using VariableNameMap = std::map<std::string, std::vector<std::string>>; using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// TODO(panyx0718): Replace vector with something like gtl::Vector.
using VariableValueMap = std::map<std::string, std::vector<Variable*>>;
// The order should be as same as framework.proto // The order should be as same as framework.proto
using Attribute = using Attribute =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册