Unverified · Commit 9faf5a39, authored by yuyang18

Refactor Operator.cc, and clean code

Parent: b0f98849
@@ -74,10 +74,10 @@ void OperatorWithKernel::Run(
   auto kernel_type_for_var = this->GetKernelTypeForVar(...);
   if (kernel_type_for_var.place_ != expected_kernel_key.place_) {
     auto* trans_var = new_scope.Var(var_name);
-    auto* out = DataTransform(expected_kernel_key,
+    auto* out = TransferData(expected_kernel_key,
                               kernel_type_for_var,
                               *tensor_in);
-    CopyVariableWithTensor(...);
+    SetTensorToVariable(...);
   }
 }
...
@@ -21,14 +21,14 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-static void PassTensorData(Tensor* from, Tensor* to) {
+static void PassTensorData(Tensor *from, Tensor *to) {
   to->ShareDataWith(*from);
   *from = Tensor();
 }

-void DataTransform(const OpKernelType& expected_kernel_type,
-                   const OpKernelType& kernel_type_for_var,
-                   const Tensor& input_tensor, Tensor* output_tensor) {
+void TransferData(const OpKernelType &expected_kernel_type,
+                  const OpKernelType &kernel_type_for_var,
+                  const Tensor &input_tensor, Tensor *output_tensor) {
   bool transformed = false;
   Tensor in;
   in.ShareDataWith(input_tensor);
@@ -89,17 +89,17 @@ void DataTransform(const OpKernelType& expected_kernel_type,
   output_tensor->ShareDataWith(in);
 }

-void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
-                            Variable* out_var) {
+void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
+                         Variable *out_var) {
   if (in_var.IsType<LoDTensor>()) {
-    auto& in_lod_tensor = in_var.Get<LoDTensor>();
-    auto* tran_lod_tensor = out_var->GetMutable<LoDTensor>();
+    auto &in_lod_tensor = in_var.Get<LoDTensor>();
+    auto *tran_lod_tensor = out_var->GetMutable<LoDTensor>();
     tran_lod_tensor->set_lod(in_lod_tensor.lod());
     tran_lod_tensor->set_layout(in_lod_tensor.layout());
     tran_lod_tensor->ShareDataWith(tensor);
   } else if (in_var.IsType<SelectedRows>()) {
-    auto& in_selected_rows = in_var.Get<SelectedRows>();
-    auto* trans_selected_rows = out_var->GetMutable<SelectedRows>();
+    auto &in_selected_rows = in_var.Get<SelectedRows>();
+    auto *trans_selected_rows = out_var->GetMutable<SelectedRows>();
     trans_selected_rows->set_height(in_selected_rows.height());
     trans_selected_rows->set_rows(in_selected_rows.rows());
     trans_selected_rows->mutable_value()->ShareDataWith(tensor);
...
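Note that SetTensorToVariable shares the underlying buffer (ShareDataWith) while copying only per-variable metadata: LoD and layout for a LoDTensor, height and rows for SelectedRows. No tensor data is duplicated. Below is a minimal, self-contained sketch of that share-data/copy-metadata pattern; FakeTensor and FakeLoDTensor are stand-in types invented for illustration, not Paddle classes:

#include <cassert>
#include <memory>
#include <vector>

// The buffer is held through a shared_ptr, mimicking Tensor::ShareDataWith.
struct FakeTensor {
  std::shared_ptr<std::vector<float>> data;
  void ShareDataWith(const FakeTensor& other) { data = other.data; }
};

// LoD is per-variable metadata that must be copied, not shared.
struct FakeLoDTensor : FakeTensor {
  std::vector<std::vector<size_t>> lod;
};

int main() {
  FakeLoDTensor in;
  in.data = std::make_shared<std::vector<float>>(4, 1.0f);
  in.lod = {{0, 2, 4}};

  FakeTensor transferred;          // result of a (hypothetical) TransferData call
  transferred.ShareDataWith(in);

  FakeLoDTensor out;               // what SetTensorToVariable builds
  out.lod = in.lod;                // copy metadata from the input variable
  out.ShareDataWith(transferred);  // share storage with the transferred tensor

  assert(out.data.get() == transferred.data.get());  // no buffer was copied
  return 0;
}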
@@ -30,12 +30,15 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-void DataTransform(const OpKernelType& expected_kernel_type,
-                   const OpKernelType& kernel_type_for_var,
-                   const Tensor& input_tensor, Tensor* out);
+void TransferData(const OpKernelType &expected_kernel_type,
+                  const OpKernelType &kernel_type_for_var,
+                  const Tensor &input_tensor, Tensor *out);

-void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
-                            Variable* out_var);
+/**
+ * Set OutVar from InVar, except that the tensor is shared with `tensor`.
+ */
+void SetTensorToVariable(const Variable &in_var, const Tensor &tensor,
+                         Variable *out_var);

 }  // namespace framework
 }  // namespace paddle
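The new doc comment pins down the contract: the output variable takes its metadata from the input variable, but its tensor storage is shared with `tensor`. The two renamed functions are meant to be used as a pipeline, exactly as at the call site in TryTransferData later in this commit:

  Tensor out;
  TransferData(expected_kernel_key, kernel_type_for_var, *tensor_in, &out);
  SetTensorToVariable(*var, out, trans_var);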
@@ -97,7 +97,7 @@ inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) {
   return ret;
 }

-inline bool TransFromNeeded(const OpKernelType& l, const OpKernelType& r) {
+inline bool NeedTransform(const OpKernelType& l, const OpKernelType& r) {
   return (!platform::places_are_same_class(l.place_, r.place_)) ||
          (l.data_type_ != r.data_type_) ||
          NeedTransformLayout(l.data_layout_, r.data_layout_);
...
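The renamed NeedTransform reads as a plain predicate: a transfer is required when the place class, the data type, or the layout of the variable's kernel type disagrees with the expected kernel type. A self-contained model of that logic follows; the enums and struct are simplified stand-ins rather than the real OpKernelType, the direct place comparison stands in for platform::places_are_same_class, and the kAnyLayout short-circuit is an assumption about what NeedTransformLayout does:

#include <cassert>

enum class PlaceClass { kCPU, kGPU };
enum class DataType { kFP32, kFP64 };
enum class Layout { kNHWC, kNCHW, kAnyLayout };

struct KernelType {
  PlaceClass place;
  DataType data_type;
  Layout layout;
};

// Assumed behaviour of NeedTransformLayout: a layout transform is needed only
// when both sides pin a concrete layout and those layouts differ.
bool NeedTransformLayout(Layout l, Layout r) {
  return l != Layout::kAnyLayout && r != Layout::kAnyLayout && l != r;
}

bool NeedTransform(const KernelType& l, const KernelType& r) {
  return l.place != r.place || l.data_type != r.data_type ||
         NeedTransformLayout(l.layout, r.layout);
}

int main() {
  KernelType for_var{PlaceClass::kCPU, DataType::kFP32, Layout::kNHWC};
  KernelType expected{PlaceClass::kGPU, DataType::kFP32, Layout::kNHWC};
  assert(NeedTransform(for_var, expected));   // device class differs

  expected.place = PlaceClass::kCPU;
  expected.layout = Layout::kAnyLayout;       // "any" never forces a transform
  assert(!NeedTransform(for_var, expected));
  return 0;
}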
@@ -620,8 +620,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
         "There are no kernels which are registered in the %s operator.", type_);
   }

-  ExecutionContext ctx(*this, scope, *dev_ctx);
-
   OpKernelMap& kernels = kernels_iter->second;

   // TODO(dzhwinter) : kernel fallback mechanism will be added when all the
@@ -631,7 +629,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   //   Do selection
   // }

-  auto expected_kernel_key = this->GetExpectedKernelType(ctx);
+  auto expected_kernel_key =
+      this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
   VLOG(3) << "expected_kernel_key:" << expected_kernel_key;

   auto kernel_iter = kernels.find(expected_kernel_key);
@@ -640,56 +639,34 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
                  KernelTypeToString(expected_kernel_key));
   }

   // do data transform
-  Scope& new_scope = scope.NewScope();
-
-  std::vector<std::string> inplace_vars;
-  for (auto& var_name_item : this->Inputs()) {
-    for (auto& var_name : var_name_item.second) {
-      auto* var = scope.FindVar(var_name);
-      if (var && VarIsTensor(var)) {
-        auto* tensor_in = GetTensorFromVar(var);
-        if (tensor_in->IsInitialized()) {
-          auto kernel_type_for_var = this->GetKernelTypeForVar(
-              var_name_item.first, *tensor_in, expected_kernel_key);
-          if (TransFromNeeded(kernel_type_for_var, expected_kernel_key)) {
-            auto out_var_names = OutputVars(true);
-            if (std::find(out_var_names.begin(), out_var_names.end(),
-                          var_name) != out_var_names.end()) {
-              inplace_vars.push_back(var_name);
-            }
-            VLOG(3) << "Transform Variable " << var_name << " from "
-                    << kernel_type_for_var << " to " << expected_kernel_key;
-            auto* trans_var = new_scope.Var(var_name);
-            std::shared_ptr<Tensor> out(new Tensor);
-            DataTransform(expected_kernel_key, kernel_type_for_var, *tensor_in,
-                          out.get());
-            CopyVariableWithTensor(*var, *(out.get()), trans_var);
-          }
-        }
-      }
-    }
-  }
+  std::vector<std::string> transfered_inplace_vars;
+  auto* transfer_scope =
+      TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);

-  auto* new_dev_ctx = pool.Get(expected_kernel_key.place_);
-  kernel_iter->second->Compute(
-      ExecutionContext(*this, new_scope, *new_dev_ctx));
+  // The exec scope is the scope that the kernel is actually executed on.
+  const Scope& exec_scope =
+      (transfer_scope == nullptr ? scope : *transfer_scope);

-  for (auto& var_name : inplace_vars) {
-    VLOG(3) << "share inplace var " + var_name + " back to its original scope";
-    auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name));
-    auto* transformed_tensor = GetTensorFromVar(new_scope.FindVar(var_name));
-    original_tensor->ShareDataWith(*transformed_tensor);
-  }
+  if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
+    dev_ctx = pool.Get(expected_kernel_key.place_);
+  }
+
+  kernel_iter->second->Compute(ExecutionContext(*this, exec_scope, *dev_ctx));
+
+  if (!transfered_inplace_vars.empty()) {
+    // Some in-place variables have been transferred; share them back.
+    TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
+  }

   /*For profiling/benchmark only*/
   if (FLAGS_benchmark) {
-    new_dev_ctx->Wait();
+    dev_ctx->Wait();
   }

   if (FLAGS_check_nan_inf) {
     for (auto& vname : OutputVars(true)) {
-      auto* var = new_scope.FindVar(vname);
+      auto* var = exec_scope.FindVar(vname);
       if (var == nullptr) continue;
       if (var->IsType<framework::LoDTensor>()) {
         CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
@@ -697,6 +674,64 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     }
   }
 }

+void OperatorWithKernel::TransferInplaceVarsBack(
+    const Scope& scope, const std::vector<std::string>& inplace_vars,
+    const Scope& transfer_scope) const {
+  for (auto& var_name : inplace_vars) {
+    VLOG(3) << "share inplace var " + var_name + " back to its original scope";
+    auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name));
+    auto* transformed_tensor =
+        GetTensorFromVar(transfer_scope.FindVar(var_name));
+    original_tensor->ShareDataWith(*transformed_tensor);
+  }
+}
+
+Scope* OperatorWithKernel::TryTransferData(
+    const Scope& scope, const OpKernelType& expected_kernel_key,
+    std::vector<std::string>* transfered_inplace_vars) const {
+  Scope* new_scope = nullptr;
+  for (auto& var_name_item : Inputs()) {
+    for (auto& var_name : var_name_item.second) {
+      auto* var = scope.FindVar(var_name);
+      // Only a tensor can be transferred to another device.
+      if (var == nullptr || !VarIsTensor(var)) {
+        continue;
+      }
+
+      auto* tensor_in = GetTensorFromVar(var);
+      if (!tensor_in->IsInitialized()) {
+        continue;
+      }
+
+      auto kernel_type_for_var = GetKernelTypeForVar(
+          var_name_item.first, *tensor_in, expected_kernel_key);
+
+      if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
+        continue;
+      }
+
+      auto out_var_names = OutputVars(true);
+      if (std::find(out_var_names.begin(), out_var_names.end(), var_name) !=
+          out_var_names.end()) {
+        transfered_inplace_vars->emplace_back(var_name);
+      }
+
+      VLOG(3) << "Transform Variable " << var_name << " from "
+              << kernel_type_for_var << " to " << expected_kernel_key;
+
+      if (new_scope == nullptr) {
+        new_scope = &scope.NewScope();
+      }
+
+      auto* trans_var = new_scope->Var(var_name);
+      Tensor out;
+      TransferData(expected_kernel_key, kernel_type_for_var, *tensor_in, &out);
+      SetTensorToVariable(*var, out, trans_var);
+    }
+  }
+
+  return new_scope;
+}
+
 proto::VarType::Type OperatorWithKernel::IndicateDataType(
     const ExecutionContext& ctx) const {
...
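The shape of this refactor is easiest to see in miniature: TryTransferData creates the child scope lazily and returns nullptr when no input needs a transfer, so RunImpl can run the kernel directly on the original scope in the common case. A compilable sketch of just that control flow, using stand-in types (the Scope and TryTransferData below are toys written for this page, not Paddle's):

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Scope {
  std::vector<std::string> vars;
  std::unique_ptr<Scope> child;
  Scope& NewScope() {
    child = std::make_unique<Scope>();
    return *child;
  }
};

// Returns nullptr when nothing had to be transferred (lazy scope creation).
Scope* TryTransferData(Scope& scope, const std::vector<std::string>& inputs,
                       bool (*need_transform)(const std::string&)) {
  Scope* new_scope = nullptr;
  for (const auto& name : inputs) {
    if (!need_transform(name)) continue;
    if (new_scope == nullptr) new_scope = &scope.NewScope();
    new_scope->vars.push_back(name);  // stands in for TransferData + SetTensorToVariable
  }
  return new_scope;
}

int main() {
  Scope scope;
  auto* transfer_scope = TryTransferData(
      scope, {"X", "Y"}, [](const std::string& n) { return n == "X"; });
  // The exec scope is whichever scope the kernel should actually run on.
  Scope& exec_scope = transfer_scope == nullptr ? scope : *transfer_scope;
  std::cout << "transferred vars: " << exec_scope.vars.size() << "\n";  // prints 1
  return 0;
}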
@@ -384,6 +384,20 @@ class OperatorWithKernel : public OperatorBase {
   // same.
   proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
+
+  /**
+   * Transfer data from scope to a transferred scope. If there is no data that
+   * needs to be transferred, it returns nullptr.
+   *
+   * transfered_inplace_vars is an output vector.
+   */
+  Scope* TryTransferData(
+      const Scope& scope, const OpKernelType& expected_kernel_key,
+      std::vector<std::string>* transfered_inplace_vars) const;
+
+  void TransferInplaceVarsBack(const Scope& scope,
+                               const std::vector<std::string>& inplace_vars,
+                               const Scope& exec_scope) const;
 };

 extern bool OpSupportGPU(const std::string& op_type);
...