From eaf8ba35b519b780629a7108d08ffd3895ac18fe Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Tue, 18 Dec 2018 09:42:57 +0800 Subject: [PATCH] change input test=develop --- paddle/fluid/framework/operator.cc | 50 ++++++++++++++++++++++++++++++ paddle/fluid/framework/operator.h | 33 +++++++++++++++----- paddle/fluid/operators/prelu_op.cc | 2 +- 3 files changed, 76 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 461d357527..87f61f3afc 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -143,12 +143,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames, for (auto& var_name_item : innames) { std::vector& input_vars = inputs[var_name_item.first]; for (auto& var_name : var_name_item.second) { + LOG(ERROR) << "first in " << var_name_item.first << ":" << var_name; input_vars.push_back(scope.FindVar(var_name)); } } for (auto& var_name_item : outnames) { std::vector& output_vars = outputs[var_name_item.first]; for (auto& var_name : var_name_item.second) { + LOG(ERROR) << "first out " << var_name_item.first << ":" << var_name; output_vars.push_back(scope.FindVar(var_name)); } } @@ -429,11 +431,52 @@ bool ExecutionContext::HasOutput(const std::string& name) const { return var != nullptr; } +const Variable* ExecutionContext::InputVar(const std::string& name) const { + auto it = ctx_.inputs.find(name); + if (it == ctx_.inputs.end()) return nullptr; + + PADDLE_ENFORCE_LE(it->second.size(), 1UL, + "Operator %s's input %s should contain only one variable.", + op_.Type(), name); + return it->second.empty() ? nullptr : it->second[0]; +} + +Variable* ExecutionContext::OutputVar(const std::string& name) const { + auto opt = op_.Output(name); + return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt); +} + +const Variable* ExecutionContext::FastInputVar(const std::string& name) const { + auto it = ctx_.inputs.find(name); + if (it == ctx_.inputs.end()) return nullptr; + + PADDLE_ENFORCE_LE(it->second.size(), 1UL, + "Operator %s's input %s should contain only one variable.", + op_.Type(), name); + return it->second.empty() ? nullptr : it->second[0]; +} + +Variable* ExecutionContext::FastOutputVar(const std::string& name) const { + auto it = ctx_.outputs.find(name); + if (it == ctx_.outputs.end()) return nullptr; + + PADDLE_ENFORCE_LE(it->second.size(), 1UL, + "Operator %s's output %s should contain only one variable.", + op_.Type(), name); + return it->second.empty() ? nullptr : it->second[0]; +} + template <> const Tensor* ExecutionContext::Input(const std::string& name) const { return Input(name); } +template <> +const Tensor* ExecutionContext::FastInput( + const std::string& name) const { + return FastInput(name); +} + template <> const std::vector ExecutionContext::MultiInput( const std::string& name) const { @@ -458,6 +501,11 @@ Tensor* ExecutionContext::Output(const std::string& name) const { return Output(name); } +template <> +Tensor* ExecutionContext::FastOutput(const std::string& name) const { + return FastOutput(name); +} + template <> std::vector ExecutionContext::MultiOutput( const std::string& name) const { @@ -822,6 +870,7 @@ Scope* OperatorWithKernel::PrepareData( auto& var_name = var_name_item.second[i]; auto* var = scope.FindVar(var_name); input_vars[i] = var; + LOG(ERROR) << "second in " << var_name_item.first << ":" << var_name; // Only tensor can be tranfer to another device. if (var == nullptr || !VarIsTensor(*var)) { @@ -882,6 +931,7 @@ Scope* OperatorWithKernel::PrepareData( for (size_t i = 0; i < var_name_item.second.size(); ++i) { auto& var_name = var_name_item.second[i]; output_vars[i] = scope.FindVar(var_name); + LOG(ERROR) << "second out " << var_name_item.first << ":" << var_name; } } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index e359414d15..0aad91dbee 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -191,15 +191,9 @@ class ExecutionContext { return op_.Outputs(name).size(); } - const Variable* InputVar(const std::string& name) const { - auto ipt = op_.Input(name); - return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - } + const Variable* InputVar(const std::string& name) const; - Variable* OutputVar(const std::string& name) const { - auto opt = op_.Output(name); - return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt); - } + Variable* OutputVar(const std::string& name) const; const std::vector MultiInputVar( const std::string& name) const { @@ -238,6 +232,22 @@ class ExecutionContext { return var == nullptr ? nullptr : var->GetMutable(); } + template + const T* FastInput(const std::string& name) const { + auto* var = FastInputVar(name); + return var == nullptr ? nullptr : &var->Get(); + } + + template + T* FastOutput(const std::string& name) const { + auto var = FastOutputVar(name); + return var == nullptr ? nullptr : var->GetMutable(); + } + + const Variable* FastInputVar(const std::string& name) const; + + Variable* FastOutputVar(const std::string& name) const; + template const std::vector MultiInput(const std::string& name) const { auto names = op_.Inputs(name); @@ -303,6 +313,10 @@ class ExecutionContext { template <> const Tensor* ExecutionContext::Input(const std::string& name) const; +template <> +const Tensor* ExecutionContext::FastInput( + const std::string& name) const; + template <> const std::vector ExecutionContext::MultiInput( const std::string& name) const; @@ -310,6 +324,9 @@ const std::vector ExecutionContext::MultiInput( template <> Tensor* ExecutionContext::Output(const std::string& name) const; +template <> +Tensor* ExecutionContext::FastOutput(const std::string& name) const; + template <> std::vector ExecutionContext::MultiOutput( const std::string& name) const; diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index 62c55c4f55..b6155ed3dd 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -56,7 +56,7 @@ class PReluOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType(ctx.Input("X")->type(), + return framework::OpKernelType(ctx.FastInput("X")->type(), ctx.device_context()); } }; -- GitLab