diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 2191dd3783d5ed7bb59b96c70d38a72bb0b2fee7..bd6d301c12e0611c5b01c3ff58869dbeb96b268e 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -27,9 +27,8 @@ limitations under the License. */ namespace paddle { namespace framework { -using DataTransformFn = - std::function ctx, - const Variable& in, Variable* out)>; +using DataTransformFn = std::function; using KernelTypePair = std::pair; struct KernelTypePairHash { diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index 4e2141ecd2ebe35402a8a04613702a2f79f6a179..5f05e881fa16eead1dc690f85375706bf3cd3e6d 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -54,18 +54,18 @@ auto kernel1 = GenFromBit({0, 0, 0, 1}); auto kernel2 = GenFromBit({0, 0, 1, 0}); auto kernel3 = GenFromBit({0, 0, 1, 1}); -void TransDataType_t(std::vector ctx, - const Variable& in, Variable* out) { +void TransDataType_t(const platform::DeviceContext* ctx, const Variable& in, + Variable* out) { test_value++; } -void TransDataLayout_t(std::vector ctx, - const Variable& in, Variable* out) { +void TransDataLayout_t(const platform::DeviceContext* ctx, const Variable& in, + Variable* out) { test_value--; } -void TransLibraryType_t(std::vector ctx, - const Variable& in, Variable* out) { +void TransLibraryType_t(const platform::DeviceContext* ctx, const Variable& in, + Variable* out) { test_value += 2; } @@ -83,7 +83,8 @@ TEST(DataTransform, Register) { using namespace paddle::platform; auto& instance = DataTransformFnMap::Instance(); - std::vector ctx; + ASSERT_EQ(instance.Map().size(), 3UL); + DeviceContext* ctx = nullptr; paddle::framework::Variable in; paddle::framework::Variable out; diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index c0be11294c4a6b49ae4bc2f805f76e9f04508349..a3ce96c409675ad52a811586c736ca22b5c7e99e 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -384,6 +384,24 @@ class RuntimeInferShapeContext : public InferShapeContext { const Scope& scope_; }; +const platform::DeviceContext* GetDeviceContext( + framework::KernelTypePair& kernel_pair) { + auto& actual_kernel_key = kernel_pair.first; + auto& expected_kernel_key = kernel_pair.second; + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + + if (platform::is_gpu_place(actual_kernel_key.place_) && + platform::is_cpu_place(expected_kernel_key.place_)) { + return pool.Get(actual_kernel_key.place_); + } else if (platform::is_cpu_place(actual_kernel_key.place_) && + platform::is_gpu_place(expected_kernel_key.place_)) { + return pool.Get(expected_kernel_key.place_); + } else { + PADDLE_THROW( + "Currently, model parallelism is only supported between CPU and CUDA"); + } +} + void OperatorWithKernel::Run(const Scope& scope, const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); @@ -418,9 +436,9 @@ void OperatorWithKernel::Run(const Scope& scope, "CPU and other devices. For example, multi-GPU model " "parallelism will failed."); } else { + auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key); const DataTransformFn* trans_fun = - DataTransformFnMap::Instance().GetNullable( - std::make_pair(actual_kernel_key, expected_kernel_key)); + DataTransformFnMap::Instance().GetNullable(kernel_pair); if (trans_fun) { auto input_vars = this->InputVars(); // TODO(qijun) filter the input vars that do not need to be transformed @@ -437,22 +455,18 @@ void OperatorWithKernel::Run(const Scope& scope, } if (!need_trans.empty()) { - // TODO(qijun) get appropriate DeviceContext from DeviceContext pool - platform::DeviceContext* trans_dev_ctx = nullptr; - std::vector trans_dev_ctx_vec{trans_dev_ctx}; + auto trans_dev_ctx = GetDeviceContext(kernel_pair); // Wait for transform starting dev_ctx->Wait(); for (auto var_name : need_trans) { - (*trans_fun)(trans_dev_ctx_vec, *(scope.FindVar(var_name)), + (*trans_fun)(trans_dev_ctx, *(scope.FindVar(var_name)), scope.FindVar(var_name + framework::KernelTypeToString( expected_kernel_key))); } // Wait for data transform finishing - for (auto ctx : trans_dev_ctx_vec) { - ctx->Wait(); - } + trans_dev_ctx->Wait(); } } }