Unverified commit 5036cf03, authored by QI JUN, committed by GitHub

add helper function to get appropriate DeviceContext (#7066)

* add helper function to get appropriate DeviceContext
Parent a096c58e
@@ -27,8 +27,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-using DataTransformFn =
-    std::function<void(const std::vector<platform::DeviceContext*> ctx,
+using DataTransformFn = std::function<void(const platform::DeviceContext* ctx,
                        const Variable& in, Variable* out)>;
 using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
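For orientation, the sketch below (not part of the commit) shows what a transform function looks like against the new signature: it receives the single DeviceContext* chosen by the framework rather than the old vector of contexts. The function name, body, and header path are illustrative assumptions.

#include "paddle/framework/data_transform.h"  // header path assumed from this PR's tree

namespace paddle {
namespace framework {

// Hypothetical transform matching the new DataTransformFn signature.
void TransDataTypeExample(const platform::DeviceContext* ctx,
                          const Variable& in, Variable* out) {
  // A real transform would convert `in` on ctx's device and write the
  // result to `out`. Work may be queued asynchronously: the caller in
  // OperatorWithKernel::Run() calls Wait() on the context afterwards.
}

}  // namespace framework
}  // namespace paddle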
@@ -54,18 +54,18 @@ auto kernel1 = GenFromBit({0, 0, 0, 1});
 auto kernel2 = GenFromBit({0, 0, 1, 0});
 auto kernel3 = GenFromBit({0, 0, 1, 1});

-void TransDataType_t(std::vector<platform::DeviceContext*> ctx,
-                     const Variable& in, Variable* out) {
+void TransDataType_t(const platform::DeviceContext* ctx, const Variable& in,
+                     Variable* out) {
   test_value++;
 }

-void TransDataLayout_t(std::vector<platform::DeviceContext*> ctx,
-                       const Variable& in, Variable* out) {
+void TransDataLayout_t(const platform::DeviceContext* ctx, const Variable& in,
+                       Variable* out) {
   test_value--;
 }

-void TransLibraryType_t(std::vector<platform::DeviceContext*> ctx,
-                        const Variable& in, Variable* out) {
+void TransLibraryType_t(const platform::DeviceContext* ctx, const Variable& in,
+                        Variable* out) {
   test_value += 2;
 }
@@ -83,7 +83,8 @@ TEST(DataTransform, Register) {
   using namespace paddle::platform;
   auto& instance = DataTransformFnMap::Instance();
-  std::vector<DeviceContext*> ctx;
+  ASSERT_EQ(instance.Map().size(), 3UL);
+  DeviceContext* ctx = nullptr;
   paddle::framework::Variable in;
   paddle::framework::Variable out;
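The lookup path the test exercises can be sketched like this (a hedged example: kernel1 and kernel2 are the GenFromBit fixtures above, and ctx, in, and out mirror the test's locals):

// GetNullable returns nullptr when no function is registered for the
// pair, so callers can branch instead of aborting.
auto pair = std::make_pair(kernel1, kernel2);
const paddle::framework::DataTransformFn* fn =
    paddle::framework::DataTransformFnMap::Instance().GetNullable(pair);
if (fn != nullptr) {
  // Under the new signature, ctx is a single DeviceContext*.
  (*fn)(ctx, in, &out);
}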
@@ -384,6 +384,24 @@ class RuntimeInferShapeContext : public InferShapeContext {
   const Scope& scope_;
 };

+const platform::DeviceContext* GetDeviceContext(
+    framework::KernelTypePair& kernel_pair) {
+  auto& actual_kernel_key = kernel_pair.first;
+  auto& expected_kernel_key = kernel_pair.second;
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+
+  if (platform::is_gpu_place(actual_kernel_key.place_) &&
+      platform::is_cpu_place(expected_kernel_key.place_)) {
+    return pool.Get(actual_kernel_key.place_);
+  } else if (platform::is_cpu_place(actual_kernel_key.place_) &&
+             platform::is_gpu_place(expected_kernel_key.place_)) {
+    return pool.Get(expected_kernel_key.place_);
+  } else {
+    PADDLE_THROW(
+        "Currently, model parallelism is only supported between CPU and CUDA");
+  }
+}
+
 void OperatorWithKernel::Run(const Scope& scope,
                              const platform::Place& place) const {
   RuntimeInferShapeContext infer_shape_ctx(*this, scope);
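The new helper encodes one rule: when the actual and expected kernels sit on opposite sides of a CPU/GPU boundary, the GPU side's place selects the DeviceContext from the pool, so the transform runs on the context that owns the device memory; any other combination throws. A standalone analogue over plain Places (function name hypothetical, header paths assumed):

#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"

// Mirrors GetDeviceContext's decision rule, written against Places
// instead of kernel keys. Returns nullptr where the PR would throw.
const paddle::platform::DeviceContext* PickTransformContext(
    const paddle::platform::Place& actual,
    const paddle::platform::Place& expected) {
  auto& pool = paddle::platform::DeviceContextPool::Instance();
  if (paddle::platform::is_gpu_place(actual) &&
      paddle::platform::is_cpu_place(expected)) {
    return pool.Get(actual);    // GPU to CPU: use the source GPU's context
  }
  if (paddle::platform::is_cpu_place(actual) &&
      paddle::platform::is_gpu_place(expected)) {
    return pool.Get(expected);  // CPU to GPU: use the target GPU's context
  }
  return nullptr;  // unsupported pair; the PR raises PADDLE_THROW here
}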
@@ -418,9 +436,9 @@ void OperatorWithKernel::Run(const Scope& scope,
             "CPU and other devices. For example, multi-GPU model "
             "parallelism will failed.");
       } else {
+        auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key);
         const DataTransformFn* trans_fun =
-            DataTransformFnMap::Instance().GetNullable(
-                std::make_pair(actual_kernel_key, expected_kernel_key));
+            DataTransformFnMap::Instance().GetNullable(kernel_pair);
         if (trans_fun) {
           auto input_vars = this->InputVars();
           // TODO(qijun) filter the input vars that do not need to be transformed
@@ -437,22 +455,18 @@ void OperatorWithKernel::Run(const Scope& scope,
         }

         if (!need_trans.empty()) {
-          // TODO(qijun) get appropriate DeviceContext from DeviceContext pool
-          platform::DeviceContext* trans_dev_ctx = nullptr;
-          std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
+          auto trans_dev_ctx = GetDeviceContext(kernel_pair);

           // Wait for transform starting
           dev_ctx->Wait();

           for (auto var_name : need_trans) {
-            (*trans_fun)(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
+            (*trans_fun)(trans_dev_ctx, *(scope.FindVar(var_name)),
                          scope.FindVar(var_name + framework::KernelTypeToString(
                                            expected_kernel_key)));
           }

           // Wait for data transform finishing
-          for (auto ctx : trans_dev_ctx_vec) {
-            ctx->Wait();
-          }
+          trans_dev_ctx->Wait();
         }
       }
     }
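Taken together, the updated Run() follows a three-step handshake, restated below with the diff's own identifiers (declarations elided; this is a summary sketch, not a complete excerpt):

// 1. Drain the kernel's own context so transforms see settled inputs.
dev_ctx->Wait();
// 2. Queue every needed transform on the single context picked by
//    GetDeviceContext(kernel_pair).
for (auto var_name : need_trans) {
  (*trans_fun)(trans_dev_ctx, *(scope.FindVar(var_name)),
               scope.FindVar(var_name + framework::KernelTypeToString(
                                 expected_kernel_key)));
}
// 3. One Wait() on that context replaces the old loop over a vector.
trans_dev_ctx->Wait();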