Unverified commit 2ccf77d1, authored by chengduo, committed by GitHub

Refine GetTensorFromVar (#14160)

* fix GetTensorFromVar
test=release/1.1

* refine GetTensorFromVar
test=develop
Parent d2e622f3
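
Note: the diff below changes GetTensorFromVar (and VarIsTensor) to take a `const Variable&` instead of a `Variable*`, so the null check moves from the callee to each call site. A minimal, compilable sketch of that caller-side pattern, assuming stand-in `Variable`/`Tensor` structs and a map-based scope rather than Paddle's real types:

// Sketch (not Paddle's real types) of the pointer-to-reference migration
// this commit makes: the callee now takes `const Variable&` and may assume
// it is valid, while the null check stays at the call site.
#include <cassert>
#include <map>
#include <string>

struct Tensor { int numel = 0; };
struct Variable { Tensor tensor; };

const Tensor* GetTensorFromVar(const Variable& var) { return &var.tensor; }

int main() {
  std::map<std::string, Variable> scope{{"X", Variable{Tensor{8}}}};
  auto it = scope.find("X");
  // Caller-side handling, mirroring `var == nullptr ? nullptr : ...` in the diff.
  const Tensor* t = it == scope.end() ? nullptr : GetTensorFromVar(it->second);
  assert(t != nullptr && t->numel == 8);
  return 0;
}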
paddle/fluid/framework/operator.cc
@@ -354,18 +354,18 @@ void OperatorBase::GenerateTemporaryNames() {
   }
 }
 
-static bool VarIsTensor(const Variable* var) {
-  return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
+static bool VarIsTensor(const Variable& var) {
+  return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
 }
 
-const Tensor* GetTensorFromVar(Variable* var) {
-  if (var->IsType<LoDTensor>()) {
-    return var->GetMutable<LoDTensor>();
-  } else if (var->IsType<SelectedRows>()) {
-    return var->GetMutable<SelectedRows>()->mutable_value();
+const Tensor* GetTensorFromVar(const Variable& var) {
+  if (var.IsType<LoDTensor>()) {
+    return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
+  } else if (var.IsType<SelectedRows>()) {
+    return &(var.Get<SelectedRows>().value());
   } else {
     PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
-                 var->Type().name());
+                 var.Type().name());
   }
 }
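
The new body reads through `Get<T>()` rather than `GetMutable<T>()`, which is what lets the function accept a `const Variable&` and return a `const Tensor*`. A single-type sketch of that const-correctness point (the real Variable is a tagged union over LoDTensor, SelectedRows, and others):

// Why the new body calls Get<T>() instead of GetMutable<T>(): Get() is a
// const member, so it can be used through `const Variable&`, and the result
// stays const. Single-type sketch, not Paddle's real Variable.
#include <cassert>

struct LoDTensor { int rows = 0; };

class Variable {
 public:
  template <typename T>
  const T& Get() const { return value_; }  // read-only, works on const objects
  template <typename T>
  T* GetMutable() { return &value_; }      // non-const member: needs a mutable Variable
 private:
  LoDTensor value_;
};

const LoDTensor* GetTensorFromVar(const Variable& var) {
  return &var.Get<LoDTensor>();            // compiles: Get() is const
  // return var.GetMutable<LoDTensor>();   // would not compile on a const reference
}

int main() {
  Variable v;
  assert(GetTensorFromVar(v)->rows == 0);
  return 0;
}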
@@ -415,8 +415,7 @@ bool ExecutionContext::HasOutput(const std::string& name) const {
 template <>
 const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
   auto* var = InputVar(name);
-  return var == nullptr ? nullptr
-                        : GetTensorFromVar(const_cast<Variable*>(var));
+  return var == nullptr ? nullptr : GetTensorFromVar(*var);
 }
 
 template <>
@@ -428,7 +427,7 @@ const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
   std::transform(names.begin(), names.end(), std::back_inserter(res),
                  [&](const std::string& sub_name) {
                    auto var = scope_.FindVar(sub_name);
-                   return var == nullptr ? nullptr : GetTensorFromVar(var);
+                   return var == nullptr ? nullptr : GetTensorFromVar(*var);
                  });
   return res;
 }
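
The MultiInput hunk keeps the same transform-over-names shape and only dereferences `var` after the null check. Reduced to standard C++, with a plain std::map standing in for Paddle's Scope:

// The MultiInput pattern above in standard C++: map each input name to a
// tensor pointer, keeping nullptr for names missing from the scope.
#include <algorithm>
#include <iterator>
#include <map>
#include <string>
#include <vector>

struct Tensor { int numel = 0; };
using Scope = std::map<std::string, Tensor>;

std::vector<const Tensor*> MultiInput(const Scope& scope,
                                      const std::vector<std::string>& names) {
  std::vector<const Tensor*> res;
  res.reserve(names.size());
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [&](const std::string& sub_name) -> const Tensor* {
                   auto it = scope.find(sub_name);
                   return it == scope.end() ? nullptr : &it->second;
                 });
  return res;
}

int main() {
  Scope scope{{"a", Tensor{3}}};
  auto tensors = MultiInput(scope, {"a", "missing"});
  return tensors[0] != nullptr && tensors[1] == nullptr ? 0 : 1;
}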
@@ -770,8 +769,10 @@ void OperatorWithKernel::TransferInplaceVarsBack(
   for (auto& var_name : inplace_vars) {
     VLOG(3) << "share inplace var " + var_name + " back to it's original scope";
     auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name));
-    auto* transformed_tensor =
-        GetTensorFromVar(transfer_scope.FindVar(var_name));
+    auto* var = transfer_scope.FindVar(var_name);
+    PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr",
+                   var_name);
+    auto* transformed_tensor = GetTensorFromVar(*var);
     original_tensor->ShareDataWith(*transformed_tensor);
   }
 }
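
TransferInplaceVarsBack now splits lookup from use, so a missing variable fails fast with its name instead of crashing inside GetTensorFromVar. The same shape in a self-contained sketch, with ENFORCE as a simplified stand-in for PADDLE_ENFORCE:

// Fail-fast lookup: find the variable once, enforce non-null with a message
// that names it, then dereference safely.
#include <map>
#include <stdexcept>
#include <string>

#define ENFORCE(cond, msg)                          \
  do {                                              \
    if (!(cond)) throw std::runtime_error(msg);     \
  } while (0)

struct Tensor {};
struct Variable { Tensor tensor; };

std::map<std::string, Variable>& TransferScope() {
  static std::map<std::string, Variable> scope{{"x", Variable{}}};
  return scope;
}

Variable* FindVar(const std::string& name) {
  auto it = TransferScope().find(name);
  return it == TransferScope().end() ? nullptr : &it->second;
}

const Tensor* LookupChecked(const std::string& var_name) {
  Variable* var = FindVar(var_name);
  ENFORCE(var != nullptr, "The var[" + var_name + "] should not be nullptr");
  return &var->tensor;  // safe: enforced non-null above
}

int main() { return LookupChecked("x") != nullptr ? 0 : 1; }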
@@ -784,11 +785,11 @@ Scope* OperatorWithKernel::TryTransferData(
     for (auto& var_name : var_name_item.second) {
       auto* var = scope.FindVar(var_name);
       // Only tensor can be tranfer to another device.
-      if (var == nullptr || !VarIsTensor(var)) {
+      if (var == nullptr || !VarIsTensor(*var)) {
         continue;
       }
-      auto* tensor_in = GetTensorFromVar(var);
+      auto* tensor_in = GetTensorFromVar(*var);
       if (!tensor_in->IsInitialized()) {
         continue;
       }
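
TryTransferData only forwards variables that exist, hold a tensor, and are initialized; everything else is skipped with `continue`. A sketch of that filter, using simplified stand-ins for Variable and Tensor:

// Filtering loop in miniature: only initialized tensor variables are
// candidates for device transfer.
#include <vector>

struct Tensor {
  bool initialized = false;
  bool IsInitialized() const { return initialized; }
};
struct Variable {
  Tensor tensor;
  bool holds_tensor = true;
};

static bool VarIsTensor(const Variable& var) { return var.holds_tensor; }
static const Tensor* GetTensorFromVar(const Variable& var) { return &var.tensor; }

std::vector<const Tensor*> TransferCandidates(const std::vector<Variable*>& vars) {
  std::vector<const Tensor*> out;
  for (Variable* var : vars) {
    if (var == nullptr || !VarIsTensor(*var)) continue;  // only tensors can move
    const Tensor* tensor_in = GetTensorFromVar(*var);
    if (!tensor_in->IsInitialized()) continue;           // skip uninitialized data
    out.push_back(tensor_in);
  }
  return out;
}

int main() {
  Variable v{Tensor{true}, true};
  std::vector<Variable*> vars{&v, nullptr};
  return TransferCandidates(vars).size() == 1 ? 0 : 1;
}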
paddle/fluid/framework/operator.h
@@ -63,7 +63,7 @@ inline std::string GradVarName(const std::string& var_name) {
 }
 
 proto::VarType::Type GetDataTypeOfVar(const Variable* var);
-const Tensor* GetTensorFromVar(Variable* var);
+const Tensor* GetTensorFromVar(const Variable& var);
 
 class OperatorBase;
 class ExecutionContext;
paddle/fluid/operators/sum_op.cc
@@ -67,6 +67,7 @@ class SumOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto x_vars = ctx.MultiInputVar("X");
+    auto x_vars_name = ctx.Inputs("X");
 
     framework::LibraryType library{framework::LibraryType::kPlain};
     framework::DataLayout layout{framework::DataLayout::kAnyLayout};
@@ -81,10 +82,11 @@ class SumOp : public framework::OperatorWithKernel {
     if (x_vars[0]->IsType<framework::LoDTensor>()) {
       int dtype = -1;
-      for (auto& x_var : x_vars) {
+      for (size_t idx = 0; idx < x_vars.size(); ++idx) {
+        PADDLE_ENFORCE(x_vars[idx] != nullptr,
+                       "Input var[%s] should not be nullptr", x_vars_name[idx]);
         // FIXME(zcd): The input x_var may be SelectedRows or LoDTensor.
-        auto tensor = framework::GetTensorFromVar(
-            const_cast<framework::Variable*>(x_var));
+        auto tensor = framework::GetTensorFromVar(*x_vars[idx]);
         if (tensor->numel() == 0) {
           continue;
         }
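
The SumOp hunk switches from a range-for to an indexed loop purely so the diagnostic can name the offending input through the parallel `x_vars_name` vector. In miniature, with stand-in types:

// Indexed iteration gives access to the matching name for error messages.
#include <stdexcept>
#include <string>
#include <vector>

struct Variable {};

void CheckSumInputs(const std::vector<const Variable*>& x_vars,
                    const std::vector<std::string>& x_vars_name) {
  for (size_t idx = 0; idx < x_vars.size(); ++idx) {
    if (x_vars[idx] == nullptr) {
      // The index pairs each variable with its name in the parallel vector.
      throw std::runtime_error("Input var[" + x_vars_name[idx] +
                               "] should not be nullptr");
    }
  }
}

int main() {
  Variable a, b;
  CheckSumInputs({&a, &b}, {"x0", "x1"});  // passes: no null inputs
  return 0;
}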