diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h
index 73f894a3e20ab779f8607e63a67139b0e8cce79a..2191dd3783d5ed7bb59b96c70d38a72bb0b2fee7 100644
--- a/paddle/framework/data_transform.h
+++ b/paddle/framework/data_transform.h
@@ -27,7 +27,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-using DataTransformFN =
+using DataTransformFn =
     std::function<void(const std::vector<platform::DeviceContext*> ctx,
                        const Variable& in, Variable* out)>;
 using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
@@ -47,7 +47,7 @@ struct KernelTypePairHash {
 };
 
 using DataTransformMap =
-    std::unordered_map<KernelTypePair, DataTransformFN, KernelTypePairHash>;
+    std::unordered_map<KernelTypePair, DataTransformFn, KernelTypePairHash>;
 
 class DataTransformFnMap {
  public:
@@ -58,25 +58,25 @@ class DataTransformFnMap {
   }
 
   void Insert(const OpKernelType& left, const OpKernelType& right,
-              const DataTransformFN& data_tranform_fn) {
+              const DataTransformFn& data_tranform_fn) {
     Insert(std::make_pair(left, right), data_tranform_fn);
   }
 
   void Insert(const KernelTypePair& kernel_type_pair,
-              const DataTransformFN& data_tranform_fn) {
+              const DataTransformFn& data_tranform_fn) {
     PADDLE_ENFORCE(!Has(kernel_type_pair),
                    "KernelTypePair %s has been registered", "");
     map_.insert({kernel_type_pair, data_tranform_fn});
   }
 
-  const DataTransformFN& Get(const KernelTypePair& key_pair) const {
+  const DataTransformFn& Get(const KernelTypePair& key_pair) const {
     auto data_transformer = GetNullable(key_pair);
     PADDLE_ENFORCE_NOT_NULL(data_transformer,
-                            "DataTransformFN should not be NULL");
+                            "DataTransformFn should not be NULL");
     return *data_transformer;
   }
 
-  const DataTransformFN* GetNullable(const KernelTypePair& key_pair) const {
+  const DataTransformFn* GetNullable(const KernelTypePair& key_pair) const {
     auto it = map_.find(key_pair);
     if (it == map_.end()) {
       return nullptr;
diff --git a/paddle/framework/op_kernel_type.h b/paddle/framework/op_kernel_type.h
index 97b542e345feab0bab701dd967558ce23375dc7f..b06002096fb109da806809f7b908d9768cf095ba 100644
--- a/paddle/framework/op_kernel_type.h
+++ b/paddle/framework/op_kernel_type.h
@@ -68,6 +68,8 @@ struct OpKernelType {
            data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
            library_type_ == o.library_type_;
   }
+
+  bool operator!=(const OpKernelType& o) const { return !(*this == o); }
 };
 
 inline std::ostream& operator<<(std::ostream& os,
@@ -78,5 +80,11 @@ inline std::ostream& operator<<(std::ostream& os,
   return os;
 }
 
+inline std::string KernelTypeToString(const OpKernelType& kernel_key) {
+  std::ostringstream stream;
+  stream << kernel_key;
+  return stream.str();
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/op_kernel_type_test.cc b/paddle/framework/op_kernel_type_test.cc
index dd048405007974667bbb8a052b77ab8b3aa4580e..649afeee8a846b0579545f2edff77e9dbe3b4dd8 100644
--- a/paddle/framework/op_kernel_type_test.cc
+++ b/paddle/framework/op_kernel_type_test.cc
@@ -26,10 +26,8 @@ TEST(OpKernelType, ToString) {
   OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
                               LibraryType::kCUDNN);
 
-  std::ostringstream stream;
-  stream << op_kernel_type;
   ASSERT_EQ(
-      stream.str(),
+      paddle::framework::KernelTypeToString(op_kernel_type),
       "data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]");
 }
 
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 886f73e7b81c35cac573bd041e6462eb2111bf85..f48512b5c682698dae86593fb89a720eea503f7d 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -413,37 +413,51 @@ void OperatorWithKernel::Run(const Scope& scope,
   }
 
   if (actual_kernel_key == expected_kernel_key) {
-    kernel_iter->second->Compute(ctx);
+    PADDLE_ENFORCE_EQ(actual_kernel_key.place_, expected_kernel_key.place_,
+                      "Currently, model parallelism is only supported between "
+                      "CPU and other devices. For example, multi-GPU model "
+                      "parallelism will fail.");
   } else {
-    Scope& op_scope = scope.NewScope();
-    auto input_vars = this->InputVars();
-    for (auto var_name : input_vars) {
-      op_scope.Var(var_name);
-    }
-
-    // TODO(qijun) get appropriate DeviceContext from DeviceContext pool
-    platform::DeviceContext* trans_dev_ctx = nullptr;
-    std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
+    const DataTransformFn* trans_fun =
+        DataTransformFnMap::Instance().GetNullable(
+            std::make_pair(actual_kernel_key, expected_kernel_key));
+    if (trans_fun) {
+      auto input_vars = this->InputVars();
+      // TODO(qijun) filter the input vars that do not need to be transformed
+
+      // filter vars that have been transformed
+      std::vector<std::string> need_trans;
+      for (auto var_name : input_vars) {
+        auto var_name_trans =
+            var_name + framework::KernelTypeToString(expected_kernel_key);
+        if (!scope.FindVar(var_name_trans)) {
+          const_cast<Scope&>(scope).Var(var_name_trans);
+          need_trans.push_back(var_name);
+        }
+      }
 
-    // TODO(qijun) get appropriate DataTransformFN from global map
-    framework::DataTransformFN trans_fun = nullptr;
+      if (!need_trans.empty()) {
+        // TODO(qijun) get appropriate DeviceContext from DeviceContext pool
+        platform::DeviceContext* trans_dev_ctx = nullptr;
+        std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
 
-    // Wait for transform starting
-    dev_ctx->Wait();
+        // Wait for transform starting
+        dev_ctx->Wait();
 
-    for (auto var_name : input_vars) {
-      trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
-                op_scope.FindVar(var_name));
-    }
-    // Wait for data transform finishing
-    for (auto ctx : trans_dev_ctx_vec) {
-      ctx->Wait();
+        for (auto var_name : need_trans) {
+          (*trans_fun)(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
+                       scope.FindVar(var_name + framework::KernelTypeToString(
+                                                    expected_kernel_key)));
+        }
+        // Wait for data transform finishing
+        for (auto ctx : trans_dev_ctx_vec) {
+          ctx->Wait();
+        }
+      }
     }
-
-    // Create a new ExecutionContext
-    ExecutionContext op_ctx(*this, op_scope, *dev_ctx);
-    kernel_iter->second->Compute(op_ctx);
   }
+
+  kernel_iter->second->Compute(ctx);
 }
 
 OpKernelType OperatorWithKernel::GetActualKernelType(
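
Note: the sketch below is not part of the patch. It illustrates how a transform
could be registered in DataTransformFnMap and then found via the non-throwing
GetNullable lookup that OperatorWithKernel::Run now performs. The kernel-type
constructor arguments mirror op_kernel_type_test.cc above; DummyTransform,
ExampleRegistration, and the particular DataLayout/LibraryType values are
made-up assumptions for illustration.

    // Illustrative sketch only -- assumes the post-patch declarations in
    // paddle/framework/data_transform.h and op_kernel_type.h.
    #include "paddle/framework/data_transform.h"

    namespace paddle {
    namespace framework {

    // Hypothetical transform; a real one would convert layout/place/data type.
    static void DummyTransform(const std::vector<platform::DeviceContext*> ctx,
                               const Variable& in, Variable* out) {
      // device-specific copy / conversion would go here
    }

    void ExampleRegistration() {
      // Kernel types built the same way as in the unit test; the layout and
      // library values here are chosen only so the pair differs.
      OpKernelType nchw_kernel(DataType::FP32, platform::CPUPlace(),
                               DataLayout::kNCHW, LibraryType::kPlain);
      OpKernelType nhwc_kernel(DataType::FP32, platform::CPUPlace(),
                               DataLayout::kNHWC, LibraryType::kPlain);

      // Register the function for the (actual, expected) kernel-type pair.
      DataTransformFnMap::Instance().Insert(nchw_kernel, nhwc_kernel,
                                            DummyTransform);

      // Run() now uses the nullable lookup instead of the throwing Get():
      const DataTransformFn* fn = DataTransformFnMap::Instance().GetNullable(
          std::make_pair(nchw_kernel, nhwc_kernel));
      if (fn != nullptr) {
        // (*fn)(dev_ctx_vec, in_var, out_var);
      }
    }

    }  // namespace framework
    }  // namespace paddle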
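The operator.cc change also caches transformed inputs by name suffix: the
transformed copy of an input variable lives in the same scope under the
original name plus KernelTypeToString(expected_kernel_key), so a repeated Run
with the same expected kernel type skips the transform. A rough illustration,
assuming the ToString format asserted in the unit test (the variable name "X"
is made up):

    // Hypothetical illustration of the cache key used by the new Run() logic.
    std::string var_name = "X";  // an input variable of the op
    std::string var_name_trans =
        var_name + paddle::framework::KernelTypeToString(expected_kernel_key);
    // With the kernel type from the unit test, this evaluates to:
    //   "Xdata_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]"
    // If scope.FindVar(var_name_trans) already exists, Run() skips "X";
    // otherwise it creates the variable and transforms into it.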