Unverified · Commit 5036cf03 authored by QI JUN, committed by GitHub

add helper function to get appropriate DeviceContext (#7066)

* add helper function to get appropriate DeviceContext
Parent a096c58e
@@ -27,8 +27,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-using DataTransformFn =
-    std::function<void(const std::vector<platform::DeviceContext*> ctx,
-                       const Variable& in, Variable* out)>;
+using DataTransformFn = std::function<void(const platform::DeviceContext* ctx,
+                                           const Variable& in, Variable* out)>;
 using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
...
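This first hunk narrows the transform signature: every registered transform uses at most one device context, so the `std::vector<platform::DeviceContext*>` parameter becomes a single `const platform::DeviceContext*`. A minimal self-contained sketch of the same `std::function` alias pattern, using hypothetical stand-in types rather than Paddle's real `Variable` and `DeviceContext`:

```cpp
#include <functional>
#include <iostream>

// Hypothetical stand-ins for paddle::framework::Variable and
// platform::DeviceContext; only the shape of the alias matters here.
struct Variable { int data; };
struct DeviceContext { const char* name; };

// Mirrors the new DataTransformFn: one context, one input, one output.
using DataTransformFn =
    std::function<void(const DeviceContext* ctx, const Variable& in,
                       Variable* out)>;

int main() {
  DataTransformFn trans = [](const DeviceContext* ctx, const Variable& in,
                             Variable* out) {
    std::cout << "transforming on " << ctx->name << "\n";
    out->data = in.data;  // a no-op "transform", for illustration only
  };

  DeviceContext cpu{"CPU"};
  Variable in{42};
  Variable out{0};
  trans(&cpu, in, &out);  // prints: transforming on CPU
  return 0;
}
```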
@@ -54,18 +54,18 @@ auto kernel1 = GenFromBit({0, 0, 0, 1});
 auto kernel2 = GenFromBit({0, 0, 1, 0});
 auto kernel3 = GenFromBit({0, 0, 1, 1});
 
-void TransDataType_t(std::vector<platform::DeviceContext*> ctx,
-                     const Variable& in, Variable* out) {
+void TransDataType_t(const platform::DeviceContext* ctx, const Variable& in,
+                     Variable* out) {
   test_value++;
 }
 
-void TransDataLayout_t(std::vector<platform::DeviceContext*> ctx,
-                       const Variable& in, Variable* out) {
+void TransDataLayout_t(const platform::DeviceContext* ctx, const Variable& in,
+                       Variable* out) {
   test_value--;
 }
 
-void TransLibraryType_t(std::vector<platform::DeviceContext*> ctx,
-                        const Variable& in, Variable* out) {
+void TransLibraryType_t(const platform::DeviceContext* ctx, const Variable& in,
+                        Variable* out) {
   test_value += 2;
 }
@@ -83,7 +83,8 @@ TEST(DataTransform, Register) {
   using namespace paddle::platform;
 
   auto& instance = DataTransformFnMap::Instance();
-  std::vector<DeviceContext*> ctx;
+  ASSERT_EQ(instance.Map().size(), 3UL);
+  DeviceContext* ctx = nullptr;
   paddle::framework::Variable in;
   paddle::framework::Variable out;
...
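The test now asserts that exactly the three transforms registered above are present in the singleton map, and passes a null context since the stub transforms ignore it. `DataTransformFnMap::Instance()` follows the Meyers-singleton registry pattern; below is a simplified sketch of that pattern with stand-in key and value types (names here are illustrative, not Paddle's class):

```cpp
#include <functional>
#include <map>
#include <string>
#include <utility>

// Hypothetical stand-ins: Paddle keys the real map by a pair of
// OpKernelType values; strings keep the sketch self-contained.
using KernelTypePair = std::pair<std::string, std::string>;
using TransformFn = std::function<void()>;

class TransformFnMap {
 public:
  // Meyers singleton: constructed once, on first use.
  static TransformFnMap& Instance() {
    static TransformFnMap instance;
    return instance;
  }
  void Insert(const KernelTypePair& key, TransformFn fn) {
    map_[key] = std::move(fn);
  }
  // Returns nullptr on a miss instead of throwing, mirroring the
  // GetNullable lookup used by OperatorWithKernel::Run.
  const TransformFn* GetNullable(const KernelTypePair& key) const {
    auto it = map_.find(key);
    return it == map_.end() ? nullptr : &it->second;
  }
  const std::map<KernelTypePair, TransformFn>& Map() const { return map_; }

 private:
  TransformFnMap() = default;
  std::map<KernelTypePair, TransformFn> map_;
};

int main() {
  auto& registry = TransformFnMap::Instance();
  registry.Insert({"CPU", "CUDA"}, [] { /* copy host -> device */ });
  // A registered pair is found; an unregistered one yields nullptr.
  return registry.GetNullable({"CPU", "CUDA"}) != nullptr &&
                 registry.GetNullable({"CUDA", "CPU"}) == nullptr
             ? 0
             : 1;
}
```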
@@ -384,6 +384,24 @@ class RuntimeInferShapeContext : public InferShapeContext {
   const Scope& scope_;
 };
 
+const platform::DeviceContext* GetDeviceContext(
+    framework::KernelTypePair& kernel_pair) {
+  auto& actual_kernel_key = kernel_pair.first;
+  auto& expected_kernel_key = kernel_pair.second;
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+
+  if (platform::is_gpu_place(actual_kernel_key.place_) &&
+      platform::is_cpu_place(expected_kernel_key.place_)) {
+    return pool.Get(actual_kernel_key.place_);
+  } else if (platform::is_cpu_place(actual_kernel_key.place_) &&
+             platform::is_gpu_place(expected_kernel_key.place_)) {
+    return pool.Get(expected_kernel_key.place_);
+  } else {
+    PADDLE_THROW(
+        "Currently, model parallelism is only supported between CPU and CUDA");
+  }
+}
+
 void OperatorWithKernel::Run(const Scope& scope,
                              const platform::Place& place) const {
   RuntimeInferShapeContext infer_shape_ctx(*this, scope);
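The new helper always selects the CUDA-side context: for a GPU-to-CPU handoff it returns the context of the actual (GPU) kernel key, for CPU-to-GPU the context of the expected (GPU) kernel key, since the cross-device copy must be issued on the GPU device; any other combination throws. A self-contained sketch of that selection logic, with boolean stand-ins for the `is_gpu_place`/`is_cpu_place` checks (illustrative only):

```cpp
#include <stdexcept>

// Hypothetical stand-ins for platform::Place and platform::DeviceContext.
struct Place { bool is_gpu; };
struct DeviceContext { Place place; };

// Mirrors the logic of GetDeviceContext: always return the CUDA-side
// context, because the CPU<->GPU copy has to run on the GPU device.
const DeviceContext* PickTransformContext(const DeviceContext& actual,
                                          const DeviceContext& expected) {
  if (actual.place.is_gpu && !expected.place.is_gpu) {
    return &actual;    // GPU -> CPU: use the source GPU context
  } else if (!actual.place.is_gpu && expected.place.is_gpu) {
    return &expected;  // CPU -> GPU: use the destination GPU context
  }
  throw std::runtime_error(
      "only CPU<->CUDA transforms are supported in this sketch");
}

int main() {
  DeviceContext gpu{{true}};
  DeviceContext cpu{{false}};
  // A GPU -> CPU handoff picks the actual (GPU) kernel's context.
  return PickTransformContext(gpu, cpu) == &gpu ? 0 : 1;
}
```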
@@ -418,9 +436,9 @@ void OperatorWithKernel::Run(const Scope& scope,
         "CPU and other devices. For example, multi-GPU model "
         "parallelism will failed.");
   } else {
+    auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key);
     const DataTransformFn* trans_fun =
-        DataTransformFnMap::Instance().GetNullable(
-            std::make_pair(actual_kernel_key, expected_kernel_key));
+        DataTransformFnMap::Instance().GetNullable(kernel_pair);
     if (trans_fun) {
       auto input_vars = this->InputVars();
       // TODO(qijun) filter the input vars that do not need to be transformed
@@ -437,22 +455,18 @@ void OperatorWithKernel::Run(const Scope& scope,
       }
 
       if (!need_trans.empty()) {
-        // TODO(qijun) get appropriate DeviceContext from DeviceContext pool
-        platform::DeviceContext* trans_dev_ctx = nullptr;
-        std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
+        auto trans_dev_ctx = GetDeviceContext(kernel_pair);
 
         // Wait for transform starting
         dev_ctx->Wait();
 
         for (auto var_name : need_trans) {
-          (*trans_fun)(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
+          (*trans_fun)(trans_dev_ctx, *(scope.FindVar(var_name)),
                        scope.FindVar(var_name + framework::KernelTypeToString(
                                          expected_kernel_key)));
         }
         // Wait for data transform finishing
-        for (auto ctx : trans_dev_ctx_vec) {
-          ctx->Wait();
-        }
+        trans_dev_ctx->Wait();
      }
    }
  }
...
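Note the synchronization order this last hunk preserves: the op's own context is drained before the transforms start, and the transform context is drained before the kernel consumes the converted inputs; only the vector-of-contexts loop collapses into a single `Wait()`. A minimal sketch of that wait/transform/wait protocol, using a hypothetical stand-in context type rather than Paddle's API:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for platform::DeviceContext.
struct DeviceContext {
  std::string name;
  void Wait() const { std::cout << "wait on " << name << "\n"; }
};

// Sketch of the wait/transform/wait protocol in OperatorWithKernel::Run.
void RunTransforms(const DeviceContext& dev_ctx,
                   const DeviceContext& trans_dev_ctx,
                   const std::vector<std::string>& need_trans) {
  dev_ctx.Wait();  // wait for transform starting: drain prior work
  for (const auto& var_name : need_trans) {
    std::cout << "transform " << var_name << " on " << trans_dev_ctx.name
              << "\n";
  }
  trans_dev_ctx.Wait();  // wait for data transform finishing
}

int main() {
  DeviceContext gpu{"cuda:0"};
  RunTransforms(gpu, gpu, {"X", "Y"});
  return 0;
}
```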