未验证 提交 2ccf77d1 编写于 作者: C chengduo 提交者: GitHub

Refine GetTensorFromVar (#14160)

* fix GetTensorFromVar
test=release/1.1

* refine GetTensorFromVar
test=develop
上级 d2e622f3
...@@ -354,18 +354,18 @@ void OperatorBase::GenerateTemporaryNames() { ...@@ -354,18 +354,18 @@ void OperatorBase::GenerateTemporaryNames() {
} }
} }
static bool VarIsTensor(const Variable* var) { static bool VarIsTensor(const Variable& var) {
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>(); return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
} }
const Tensor* GetTensorFromVar(Variable* var) { const Tensor* GetTensorFromVar(const Variable& var) {
if (var->IsType<LoDTensor>()) { if (var.IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>(); return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
} else if (var->IsType<SelectedRows>()) { } else if (var.IsType<SelectedRows>()) {
return var->GetMutable<SelectedRows>()->mutable_value(); return &(var.Get<SelectedRows>().value());
} else { } else {
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.", PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var->Type().name()); var.Type().name());
} }
} }
...@@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const { ...@@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
template <> template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const { const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name); auto* var = InputVar(name);
return var == nullptr ? nullptr return var == nullptr ? nullptr : GetTensorFromVar(*var);
: GetTensorFromVar(const_cast<Variable*>(var));
} }
template <> template <>
...@@ -428,7 +427,7 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>( ...@@ -428,7 +427,7 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
std::transform(names.begin(), names.end(), std::back_inserter(res), std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { [&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : GetTensorFromVar(var); return var == nullptr ? nullptr : GetTensorFromVar(*var);
}); });
return res; return res;
} }
...@@ -770,8 +769,10 @@ void OperatorWithKernel::TransferInplaceVarsBack( ...@@ -770,8 +769,10 @@ void OperatorWithKernel::TransferInplaceVarsBack(
for (auto& var_name : inplace_vars) { for (auto& var_name : inplace_vars) {
VLOG(3) << "share inplace var " + var_name + " back to it's original scope"; VLOG(3) << "share inplace var " + var_name + " back to it's original scope";
auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name)); auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name));
auto* transformed_tensor = auto* var = transfer_scope.FindVar(var_name);
GetTensorFromVar(transfer_scope.FindVar(var_name)); PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr",
var_name);
auto* transformed_tensor = GetTensorFromVar(*var);
original_tensor->ShareDataWith(*transformed_tensor); original_tensor->ShareDataWith(*transformed_tensor);
} }
} }
...@@ -784,11 +785,11 @@ Scope* OperatorWithKernel::TryTransferData( ...@@ -784,11 +785,11 @@ Scope* OperatorWithKernel::TryTransferData(
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
auto* var = scope.FindVar(var_name); auto* var = scope.FindVar(var_name);
// Only tensor can be tranfer to another device. // Only tensor can be tranfer to another device.
if (var == nullptr || !VarIsTensor(var)) { if (var == nullptr || !VarIsTensor(*var)) {
continue; continue;
} }
auto* tensor_in = GetTensorFromVar(var); auto* tensor_in = GetTensorFromVar(*var);
if (!tensor_in->IsInitialized()) { if (!tensor_in->IsInitialized()) {
continue; continue;
} }
......
...@@ -63,7 +63,7 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -63,7 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
} }
proto::VarType::Type GetDataTypeOfVar(const Variable* var); proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetTensorFromVar(Variable* var); const Tensor* GetTensorFromVar(const Variable& var);
class OperatorBase; class OperatorBase;
class ExecutionContext; class ExecutionContext;
......
...@@ -67,6 +67,7 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -67,6 +67,7 @@ class SumOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto x_vars = ctx.MultiInputVar("X"); auto x_vars = ctx.MultiInputVar("X");
auto x_vars_name = ctx.Inputs("X");
framework::LibraryType library{framework::LibraryType::kPlain}; framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout{framework::DataLayout::kAnyLayout}; framework::DataLayout layout{framework::DataLayout::kAnyLayout};
...@@ -81,10 +82,11 @@ class SumOp : public framework::OperatorWithKernel { ...@@ -81,10 +82,11 @@ class SumOp : public framework::OperatorWithKernel {
if (x_vars[0]->IsType<framework::LoDTensor>()) { if (x_vars[0]->IsType<framework::LoDTensor>()) {
int dtype = -1; int dtype = -1;
for (auto& x_var : x_vars) { for (size_t idx = 0; idx < x_vars.size(); ++idx) {
PADDLE_ENFORCE(x_vars[idx] != nullptr,
"Input var[%s] should not be nullptr", x_vars_name[idx]);
// FIXME(zcd): The input x_var may be SelectedRows or LoDTensor. // FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
auto tensor = framework::GetTensorFromVar( auto tensor = framework::GetTensorFromVar(*x_vars[idx]);
const_cast<framework::Variable*>(x_var));
if (tensor->numel() == 0) { if (tensor->numel() == 0) {
continue; continue;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册