提交 eaf8ba35 编写于 作者: X Xin Pan

change input

test=develop
上级 840e6729
...@@ -143,12 +143,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames, ...@@ -143,12 +143,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
for (auto& var_name_item : innames) { for (auto& var_name_item : innames) {
std::vector<Variable*>& input_vars = inputs[var_name_item.first]; std::vector<Variable*>& input_vars = inputs[var_name_item.first];
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
LOG(ERROR) << "first in " << var_name_item.first << ":" << var_name;
input_vars.push_back(scope.FindVar(var_name)); input_vars.push_back(scope.FindVar(var_name));
} }
} }
for (auto& var_name_item : outnames) { for (auto& var_name_item : outnames) {
std::vector<Variable*>& output_vars = outputs[var_name_item.first]; std::vector<Variable*>& output_vars = outputs[var_name_item.first];
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
LOG(ERROR) << "first out " << var_name_item.first << ":" << var_name;
output_vars.push_back(scope.FindVar(var_name)); output_vars.push_back(scope.FindVar(var_name));
} }
} }
...@@ -429,11 +431,52 @@ bool ExecutionContext::HasOutput(const std::string& name) const { ...@@ -429,11 +431,52 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
return var != nullptr; return var != nullptr;
} }
const Variable* ExecutionContext::InputVar(const std::string& name) const {
auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) return nullptr;
PADDLE_ENFORCE_LE(it->second.size(), 1UL,
"Operator %s's input %s should contain only one variable.",
op_.Type(), name);
return it->second.empty() ? nullptr : it->second[0];
}
Variable* ExecutionContext::OutputVar(const std::string& name) const {
auto opt = op_.Output(name);
return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
}
const Variable* ExecutionContext::FastInputVar(const std::string& name) const {
auto it = ctx_.inputs.find(name);
if (it == ctx_.inputs.end()) return nullptr;
PADDLE_ENFORCE_LE(it->second.size(), 1UL,
"Operator %s's input %s should contain only one variable.",
op_.Type(), name);
return it->second.empty() ? nullptr : it->second[0];
}
Variable* ExecutionContext::FastOutputVar(const std::string& name) const {
auto it = ctx_.outputs.find(name);
if (it == ctx_.outputs.end()) return nullptr;
PADDLE_ENFORCE_LE(it->second.size(), 1UL,
"Operator %s's output %s should contain only one variable.",
op_.Type(), name);
return it->second.empty() ? nullptr : it->second[0];
}
template <> template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const { const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
return Input<LoDTensor>(name); return Input<LoDTensor>(name);
} }
template <>
const Tensor* ExecutionContext::FastInput<Tensor>(
const std::string& name) const {
return FastInput<LoDTensor>(name);
}
template <> template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>( const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const { const std::string& name) const {
...@@ -458,6 +501,11 @@ Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const { ...@@ -458,6 +501,11 @@ Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
return Output<LoDTensor>(name); return Output<LoDTensor>(name);
} }
template <>
Tensor* ExecutionContext::FastOutput<Tensor>(const std::string& name) const {
return FastOutput<LoDTensor>(name);
}
template <> template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>( std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const { const std::string& name) const {
...@@ -822,6 +870,7 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -822,6 +870,7 @@ Scope* OperatorWithKernel::PrepareData(
auto& var_name = var_name_item.second[i]; auto& var_name = var_name_item.second[i];
auto* var = scope.FindVar(var_name); auto* var = scope.FindVar(var_name);
input_vars[i] = var; input_vars[i] = var;
LOG(ERROR) << "second in " << var_name_item.first << ":" << var_name;
// Only tensor can be tranfer to another device. // Only tensor can be tranfer to another device.
if (var == nullptr || !VarIsTensor(*var)) { if (var == nullptr || !VarIsTensor(*var)) {
...@@ -882,6 +931,7 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -882,6 +931,7 @@ Scope* OperatorWithKernel::PrepareData(
for (size_t i = 0; i < var_name_item.second.size(); ++i) { for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto& var_name = var_name_item.second[i]; auto& var_name = var_name_item.second[i];
output_vars[i] = scope.FindVar(var_name); output_vars[i] = scope.FindVar(var_name);
LOG(ERROR) << "second out " << var_name_item.first << ":" << var_name;
} }
} }
......
...@@ -191,15 +191,9 @@ class ExecutionContext { ...@@ -191,15 +191,9 @@ class ExecutionContext {
return op_.Outputs(name).size(); return op_.Outputs(name).size();
} }
const Variable* InputVar(const std::string& name) const { const Variable* InputVar(const std::string& name) const;
auto ipt = op_.Input(name);
return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
}
Variable* OutputVar(const std::string& name) const { Variable* OutputVar(const std::string& name) const;
auto opt = op_.Output(name);
return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
}
const std::vector<const Variable*> MultiInputVar( const std::vector<const Variable*> MultiInputVar(
const std::string& name) const { const std::string& name) const {
...@@ -238,6 +232,22 @@ class ExecutionContext { ...@@ -238,6 +232,22 @@ class ExecutionContext {
return var == nullptr ? nullptr : var->GetMutable<T>(); return var == nullptr ? nullptr : var->GetMutable<T>();
} }
template <typename T>
const T* FastInput(const std::string& name) const {
auto* var = FastInputVar(name);
return var == nullptr ? nullptr : &var->Get<T>();
}
template <typename T>
T* FastOutput(const std::string& name) const {
auto var = FastOutputVar(name);
return var == nullptr ? nullptr : var->GetMutable<T>();
}
const Variable* FastInputVar(const std::string& name) const;
Variable* FastOutputVar(const std::string& name) const;
template <typename T> template <typename T>
const std::vector<const T*> MultiInput(const std::string& name) const { const std::vector<const T*> MultiInput(const std::string& name) const {
auto names = op_.Inputs(name); auto names = op_.Inputs(name);
...@@ -303,6 +313,10 @@ class ExecutionContext { ...@@ -303,6 +313,10 @@ class ExecutionContext {
template <> template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const; const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;
template <>
const Tensor* ExecutionContext::FastInput<Tensor>(
const std::string& name) const;
template <> template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>( const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
const std::string& name) const; const std::string& name) const;
...@@ -310,6 +324,9 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>( ...@@ -310,6 +324,9 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
template <> template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const; Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;
template <>
Tensor* ExecutionContext::FastOutput<Tensor>(const std::string& name) const;
template <> template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>( std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const; const std::string& name) const;
......
...@@ -56,7 +56,7 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -56,7 +56,7 @@ class PReluOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(ctx.FastInput<Tensor>("X")->type(),
ctx.device_context()); ctx.device_context());
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册