提交 900fbb83 编写于 作者: N nhzlx

add params sync pass

上级 12e1719f
cc_library(ir_graph_build_pass SRCS ir_graph_build_pass.cc DEPS analysis_pass argument ir_pass_manager)
cc_library(ir_analysis_pass SRCS ir_analysis_pass.cc DEPS analysis_pass argument ir_pass_manager)
cc_library(analysis_passes SRCS passes.cc DEPS ir_graph_build_pass ir_analysis_pass)
cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_pass.cc DEPS analysis_pass argument ir_pass_manager analysis_helper)
cc_library(analysis_passes SRCS passes.cc DEPS ir_graph_build_pass ir_analysis_pass ir_params_sync_among_devices_pass)
set(analysis_deps ${analysis_deps}
ir_graph_build_pass
......
......@@ -61,6 +61,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
void IrAnalysisComposePass::ApplyIrPasses(Argument *argument) {
std::vector<std::string> passes({
"ir_graph_build_pass", "ir_analysis_pass",
"ir_params_sync_among_devices_pass",
});
for (const auto &pass : passes) {
VLOG(2) << "Run pass " << pass;
......
......@@ -36,12 +36,7 @@ void IrGraphBuildPass::RunImpl(Argument *argument) {
// so that the parameters will on the same device, or they will keep copying
// between difference devices.
platform::Place place;
if (argument->use_gpu()) {
PADDLE_ENFORCE(argument->gpu_device_id_valid());
place = platform::CUDAPlace(argument->gpu_device_id());
} else {
place = platform::CPUPlace();
}
place = platform::CPUPlace();
if (argument->model_dir_valid()) {
auto program =
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace {
bool IsPersistable(const framework::VarDesc *var) {
if (var->Persistable() &&
var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
var->GetType() != framework::proto::VarType::FETCH_LIST) {
return true;
}
return false;
}
} // namespace
namespace inference {
namespace analysis {
void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
PADDLE_ENFORCE(argument->scope_valid());
PADDLE_ENFORCE(argument->use_gpu_valid());
platform::Place place;
// The parameters are on the cpu, therefore, synchronization is not necessary.
if (!argument->use_gpu()) return;
LOG(INFO) << "Sync params from CPU to GPU";
PADDLE_ENFORCE(argument->gpu_device_id_valid());
place = platform::CUDAPlace(argument->gpu_device_id());
auto *scope = argument->scope_ptr();
// Get the program which has been processed by several passes.
analysis_program_.reset(
new framework::ProgramDesc(argument->ir_analyzed_program()));
const auto &global_block = analysis_program_->Block(0);
// sync the params from cpu to gpu.
for (auto &var : global_block.AllVars()) {
if (IsPersistable(var)) {
std::string var_name = var->Name();
LOG(INFO) << var_name;
auto &t = inference::analysis::GetFromScope<framework::LoDTensor>(
*scope, var_name);
platform::CPUPlace cpu_place;
framework::LoDTensor temp_tensor;
temp_tensor.Resize(t.dims());
temp_tensor.mutable_data<float>(cpu_place);
// Copy the parameter data to a tmp tensor.
TensorCopySync(t, cpu_place, &temp_tensor);
// Reallocation the space on GPU
t.mutable_data<float>(place);
// Copy parameter data to newly allocated GPU space.
TensorCopySync(temp_tensor, place, &t);
}
}
}
std::string IrParamsSyncAmongDevicesPass::repr() const {
return "ir-params-sync-among-devices-pass";
}
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* Sync parameter from CPU to GPU.
*/
class IrParamsSyncAmongDevicesPass : public AnalysisPass {
public:
void RunImpl(Argument *argument) override;
std::string repr() const override;
private:
std::unique_ptr<framework::ProgramDesc> analysis_program_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -16,6 +16,7 @@
#include "paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.cc"
#include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
#include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"
namespace paddle {
namespace inference {
......@@ -27,6 +28,9 @@ PassRegistry::PassRegistry() {
std::unique_ptr<AnalysisPass>(new IrGraphBuildPass));
passes_.emplace("ir_analysis_compose_pass",
std::unique_ptr<AnalysisPass>(new IrAnalysisComposePass));
passes_.emplace(
"ir_params_sync_among_devices_pass",
std::unique_ptr<AnalysisPass>(new IrParamsSyncAmongDevicesPass));
}
} // namespace analysis
......
......@@ -116,12 +116,8 @@ class CpuPassStrategy : public PassStrategy {
class GpuPassStrategy : public PassStrategy {
public:
GpuPassStrategy() : PassStrategy({}) {
// TODO(NHZlX) Problem with Data synchronization between GPU and CPU
// When running in GPU mode, the parameters are all on GPU. But the
// opearations of "conv_bn_fuse_pass" are on CPU.
passes_.assign({
"infer_clean_graph_pass",
// "infer_clean_graph_pass", "conv_bn_fuse_pass",
"infer_clean_graph_pass", "conv_bn_fuse_pass",
});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册