Commit 53d558cd authored by J JiabinYang

test=develop, polish code and merge develop

@@ -129,10 +129,6 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
-if(WITH_NGRAPH)
-cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
-endif(WITH_NGRAPH)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
...
@@ -28,7 +28,7 @@ struct ExecutionStrategy {
// If we set this to 1, we will delete all variables when finish a batch. and
// this will loss 15%+ performance.
// Please be aware about this parameters.
-size_t num_iteration_per_drop_scope_{100};
+size_t num_iteration_per_drop_scope_{1};
ExecutorType type_{kDefault};
bool dry_run_{false};
};
...
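Note on the hunk above: the new default of 1 drops local scopes after every batch, trading the roughly 15% throughput mentioned in the comment for lower memory use. As a hedged illustration only (not part of the patch, and assuming the paddle/fluid/framework/details/execution_strategy.h header this struct lives in), a caller that prefers the old behaviour can still raise the interval explicitly:

#include "paddle/fluid/framework/details/execution_strategy.h"

// Illustration only: restore the previous, less aggressive scope cleanup.
paddle::framework::details::ExecutionStrategy MakeStrategy() {
  paddle::framework::details::ExecutionStrategy strategy;
  // Drop temporary variables every 100 iterations instead of every batch.
  strategy.num_iteration_per_drop_scope_ = 100;
  return strategy;
}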
@@ -1072,8 +1072,9 @@ Scope* OperatorWithKernel::PrepareData(
proto::VarType::Type OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
-proto::VarType::Type defaut_data_type = static_cast<proto::VarType::Type>(-1);
-proto::VarType::Type data_type = defaut_data_type;
+proto::VarType::Type dafault_data_type =
+static_cast<proto::VarType::Type>(-1);
+proto::VarType::Type data_type = dafault_data_type;
for (auto& input : this->inputs_) {
const std::vector<const Variable*> vars = ctx.MultiInputVar(input.first);
for (size_t i = 0; i < vars.size(); ++i) {
@@ -1092,7 +1093,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
input.first, i);
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(
-tmp == data_type || data_type == defaut_data_type,
+tmp == data_type || data_type == dafault_data_type,
"DataType of Paddle Op %s must be the same. Get (%d) != (%d)",
Type(), DataTypeToString(data_type), DataTypeToString(tmp));
data_type = tmp;
@@ -1100,7 +1101,7 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
}
}
}
-PADDLE_ENFORCE(data_type != dafault_data_type,
+PADDLE_ENFORCE(data_type != dafault_data_type,
"DataType should be indicated by input");
return data_type;
}
...
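The hunk above only renames the sentinel, but the rule it enforces is easy to miss in the diff: every input of an operator must share one data type, and at least one input must supply it. A hedged, framework-free sketch of the same pattern (all names below are invented for illustration):

#include <stdexcept>
#include <vector>

// Illustration only: mirrors the sentinel-based data type inference above.
int IndicateDataTypeSketch(const std::vector<int>& input_types) {
  const int kUnset = -1;  // plays the role of the default/sentinel type
  int data_type = kUnset;
  for (int tmp : input_types) {
    if (data_type != kUnset && tmp != data_type) {
      throw std::runtime_error("DataType of all inputs must be the same");
    }
    data_type = tmp;
  }
  if (data_type == kUnset) {
    throw std::runtime_error("DataType should be indicated by input");
  }
  return data_type;
}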
@@ -133,7 +133,9 @@ struct Argument {
// Memory optimized related.
DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
-DECL_ARGUMENT_FIELD(memory_optim_force_update, MemoryOptimForceUpdate, bool);
+DECL_ARGUMENT_FIELD(static_memory_optim, StaticMemoryOptim, bool);
+DECL_ARGUMENT_FIELD(static_memory_optim_force_update,
+StaticMemoryOptimForceUpdate, bool);
// Indicate which kind of sort algorithm is used for operators, the memory
// optimization relays on the sort algorithm.
DECL_ARGUMENT_FIELD(memory_optim_sort_kind, MemoryOptimSortKind, int);
...
@@ -444,6 +444,26 @@ std::vector<std::map<std::string, std::vector<int>>> DeseralizeBatchVarShapes(
return batch_shapes;
}
+// Replace the -1 in shape to a real number to fake the shape.
+std::vector<std::map<std::string, std::vector<int>>> FakeBatchVarShapes(
+const framework::ProgramDesc& program) {
+std::vector<std::map<std::string, std::vector<int>>> res;
+res.emplace_back();
+auto& record = res.front();
+const int fake_batch_size = 3;
+for (auto* var : program.Block(0).AllVars()) {
+if (var->GetType() ==
+framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
+auto shape = var->GetShape();
+for (auto& v : shape) {
+if (v < 0) v = fake_batch_size;
+}
+record[var->Name()].assign(shape.begin(), shape.end());
+}
+}
+return res;
+}
// Calculate the average dim of each tensor from the batch shape cache.
std::unordered_map<std::string, size_t> GetBatchAverageSize(
const std::vector<std::map<std::string, std::vector<int>>>& batches) {
@@ -478,6 +498,7 @@ std::vector<std::unordered_set<std::string>> AnalysisBatchShapesByBatchSize(
std::unordered_map<std::string, std::stringstream> var_batchsize_hashes;
for (auto& batch : batches) {
for (auto& ele : batch) {
+PADDLE_ENFORCE(!ele.second.empty());
int batch_size = ele.second.front();
// TODO(Superjomn) might consume large memory here, use combine hash.
var_batchsize_hashes[ele.first] << batch_size;
@@ -538,9 +559,21 @@ std::vector<std::unordered_set<std::string>> AnalysisBatchShapesBySimilarSize(
std::string MemoryOptimizePass::repr() const { return "memory optimize pass"; }
+std::pair<size_t, size_t> GetRange(
+const std::unordered_map<std::string, size_t>& ave_size) {
+auto res = std::make_pair(std::numeric_limits<size_t>::max(),
+std::numeric_limits<size_t>::min());
+for (auto& item : ave_size) {
+res.first = std::min(item.second, res.first);
+res.second = std::max(item.second, res.second);
+}
+return res;
+}
void MemoryOptimizePass::RunImpl(Argument* argument) {
// When force update, should not optimize memory.
-if (!argument->enable_memory_optim() || argument->memory_optim_force_update())
+if (!argument->enable_memory_optim() ||
+argument->static_memory_optim_force_update())
return;
graph_ = argument->main_graph_ptr();
@@ -549,11 +582,26 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
argument->model_program_path_valid() ? argument->model_program_path()
: "");
VLOG(3) << "Load memory cache from " << path;
-if (inference::IsFileExists(path)) {
-VLOG(4) << "Performing memory optimize";
-auto batches = DeseralizeBatchVarShapes(path);
+std::vector<std::map<std::string, std::vector<int>>> batches;
+if (argument->static_memory_optim() && inference::IsFileExists(path)) {
+string::PrettyLogInfo("--- Performing static memory optimize");
+batches = DeseralizeBatchVarShapes(path);
+} else {
+string::PrettyLogInfo("--- Performing dynamic memory optimize");
+batches = FakeBatchVarShapes(argument->main_program());
+}
auto var_batch_ave_size = GetBatchAverageSize(batches);
+// Get min and max memory size.
+const auto range = GetRange(var_batch_ave_size);
+const int cluster_size = std::max(
+static_cast<int>((range.second - range.first) / 100 /*cluster num*/),
+1024);
+const int cluster_size1 = std::max(
+static_cast<int>((range.second - range.first) / 1000 /*cluster num*/),
+1024);
std::unordered_map<std::string, Node*> tensor_nodes;
space_table_t space_table;
CollectVarMemorySize(var_batch_ave_size, &tensor_nodes, &space_table);
@@ -564,6 +612,8 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
std::vector<std::function<MemoryAllocation()>> strategies;
for (int sort_kind = 0; sort_kind < 2; sort_kind++) {
+if (argument->static_memory_optim()) {
+// This strategy only make scene in static memory optimize.
strategies.emplace_back([&, sort_kind] {
auto clustered_vars_by_batch_size =
AnalysisBatchShapesByBatchSize(batches);
@@ -572,22 +622,23 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
space_table, &reuse_table, sort_kind, &allocation);
return allocation;
});
+}
strategies.emplace_back([&, sort_kind] {
-auto clustered_vars_by_ave_size = AnalysisBatchShapesBySimilarSize(
-space_table, batches, 1024); // interval 1kb
+auto clustered_vars_by_ave_size =
+AnalysisBatchShapesBySimilarSize(space_table, batches, cluster_size);
MemoryAllocation allocation;
-MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size,
-space_table, &reuse_table, sort_kind, &allocation);
+MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size, space_table,
+&reuse_table, sort_kind, &allocation);
return allocation;
});
strategies.emplace_back([&, sort_kind] {
-auto clustered_vars_by_ave_size = AnalysisBatchShapesBySimilarSize(
-space_table, batches, 1024 * 1024); // interval 1MB
+auto clustered_vars_by_ave_size =
+AnalysisBatchShapesBySimilarSize(space_table, batches, cluster_size1);
MemoryAllocation allocation;
-MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size,
-space_table, &reuse_table, sort_kind, &allocation);
+MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size, space_table,
+&reuse_table, sort_kind, &allocation);
return allocation;
});
@@ -596,8 +647,8 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
space_table, batches,
std::numeric_limits<int>::max()); // no intervals
MemoryAllocation allocation;
-MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size,
-space_table, &reuse_table, sort_kind, &allocation);
+MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size, space_table,
+&reuse_table, sort_kind, &allocation);
return allocation;
});
}
@@ -615,19 +666,15 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
}
}
if (!best_strategy) {
-LOG(ERROR)
-<< "This model makes poor memory optimize, skip memory optimize";
+LOG(ERROR) << "This model makes poor memory optimize, skip memory optimize";
return;
}
auto memory_allocation = (*best_strategy)();
-string::PrettyLogH2(
+string::PrettyLogInfo(
"--- Saved %.2f%s memory for workspace(temporary variables)",
memory_allocation.GetSavingRatio() * 100, "%");
-string::PrettyLogDetail("--- Allocated %d MB",
-memory_allocation.allocated / 1024. / 1024.);
-string::PrettyLogDetail("--- Saved %d MB",
-memory_allocation.saved / 1024. / 1024.);
argument->main_graph().Set(framework::ir::kGraphToProgramVarsToRemove,
new std::unordered_set<std::string>);
auto& vars2remove =
@@ -636,7 +683,6 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
PerformReusePlan(reuse_table, memory_allocation.sort_kind, &vars2remove);
argument->SetMemoryOptimSortKind(memory_allocation.sort_kind);
-}
}
float MemoryOptimizePass::MemoryAllocation::GetSavingRatio() const {
...
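The clustering interval used above is no longer a fixed 1 KB / 1 MB; it is derived from the spread of the average variable sizes and clamped to at least 1024 bytes. A small, self-contained illustration of that arithmetic, with made-up sizes (the 4 KB and 40 MB figures below are invented, not taken from the patch):

#include <algorithm>
#include <cstdio>

int main() {
  // Pretend the smallest average variable is 4 KB and the largest is 40 MB.
  const long long min_size = 4LL * 1024;
  const long long max_size = 40LL * 1024 * 1024;
  // Same formula as above: spread / 100 (or / 1000), but never below 1024 bytes.
  const int cluster_size = std::max(static_cast<int>((max_size - min_size) / 100), 1024);
  const int cluster_size1 = std::max(static_cast<int>((max_size - min_size) / 1000), 1024);
  std::printf("coarse interval: %d bytes, fine interval: %d bytes\n", cluster_size, cluster_size1);
  return 0;
}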
@@ -13,9 +13,11 @@
// limitations under the License.
#pragma once
+#include <string>
+#include <utility>
+#include <vector>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
-#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
+#include "paddle/fluid/platform/port.h"
namespace paddle {
namespace inference {
...
@@ -95,7 +95,8 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
CP_MEMBER(memory_pool_init_size_mb_);
CP_MEMBER(enable_memory_optim_);
-CP_MEMBER(memory_optim_force_update_);
+CP_MEMBER(static_memory_optim_);
+CP_MEMBER(static_memory_optim_force_update_);
// TensorRT releated.
CP_MEMBER(use_tensorrt_);
CP_MEMBER(tensorrt_workspace_size_);
@@ -238,7 +239,8 @@ std::string contrib::AnalysisConfig::SerializeInfoCache() {
ss << tensorrt_min_subgraph_size_;
ss << enable_memory_optim_;
-ss << memory_optim_force_update_;
+ss << static_memory_optim_;
+ss << static_memory_optim_force_update_;
ss << use_mkldnn_;
for (auto &item : mkldnn_enabled_op_types_) ss << item;
@@ -278,9 +280,11 @@ float contrib::AnalysisConfig::fraction_of_gpu_memory_for_pool() const {
#endif
}
-void contrib::AnalysisConfig::EnableMemoryOptim(bool force_update_cache) {
+void contrib::AnalysisConfig::EnableMemoryOptim(
+bool static_optim, bool force_update_static_cache) {
enable_memory_optim_ = true;
-memory_optim_force_update_ = force_update_cache;
+static_memory_optim_ = static_optim;
+static_memory_optim_force_update_ = force_update_static_cache;
Update();
}
@@ -300,4 +304,16 @@ void contrib::AnalysisConfig::SetModelBuffer(const char *prog_buffer,
Update();
}
+NativeConfig contrib::AnalysisConfig::ToNativeConfig() const {
+NativeConfig config;
+config.model_dir = model_dir_;
+config.prog_file = prog_file_;
+config.param_file = params_file_;
+config.use_gpu = use_gpu_;
+config.device = device_id_;
+config.fraction_of_gpu_memory = fraction_of_gpu_memory_for_pool();
+config.specify_input_name = specify_input_name_;
+return config;
+}
} // namespace paddle
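From a caller's point of view the two pieces changed in this file combine as follows. This is a hedged usage sketch only, assuming the usual public header path and a default-constructible AnalysisConfig; it is not part of the patch:

#include "paddle/fluid/inference/api/paddle_analysis_config.h"

paddle::NativeConfig ConfigureSketch() {
  paddle::contrib::AnalysisConfig config;
  // Dynamic memory optimization: no recorded shape cache needed.
  config.EnableMemoryOptim();
  // Or static memory optimization, forcing a refresh of the shape cache:
  // config.EnableMemoryOptim(true /*static_optim*/, true /*force_update_static_cache*/);
  // ToNativeConfig still yields an equivalent NativeConfig; only its body moved out of the header.
  return config.ToNativeConfig();
}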
@@ -298,15 +298,15 @@ void AnalysisPredictor::GetFetchOne(const framework::LoDTensor &fetch,
bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
framework::Scope *scope) {
VLOG(3) << "Predictor::get_fetch";
-outputs->resize(fetchs_.size());
-for (size_t i = 0; i < fetchs_.size(); ++i) {
-int idx = boost::get<int>(fetchs_[i]->GetAttr("col"));
+outputs->resize(fetches_.size());
+for (size_t i = 0; i < fetches_.size(); ++i) {
+int idx = boost::get<int>(fetches_[i]->GetAttr("col"));
PADDLE_ENFORCE((size_t)idx == i);
framework::LoDTensor &fetch =
framework::GetFetchVariable(*scope, "fetch", idx);
auto type = fetch.type();
auto output = &(outputs->at(i));
-output->name = fetchs_[idx]->Input("X")[0];
+output->name = fetches_[idx]->Input("X")[0];
if (type == framework::proto::VarType::FP32) {
GetFetchOne<float>(fetch, output);
output->dtype = PaddleDType::FLOAT32;
@@ -327,7 +327,9 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
argument_.SetUseGPU(config_.use_gpu());
argument_.SetGPUDeviceId(config_.gpu_device_id());
argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
-argument_.SetMemoryOptimForceUpdate(config_.memory_optim_force_update_);
+argument_.SetStaticMemoryOptim(config_.static_memory_optim_);
+argument_.SetStaticMemoryOptimForceUpdate(
+config_.static_memory_optim_force_update_);
argument_.SetModelFromMemory(config_.model_from_memory_);
// Analyze inference_program
if (!config_.model_dir().empty()) {
@@ -422,10 +424,10 @@ void AnalysisPredictor::PrepareFeedFetch() {
feed_names_[op->Output("Out")[0]] = idx;
} else if (op->Type() == "fetch") {
int idx = boost::get<int>(op->GetAttr("col"));
-if (fetchs_.size() <= static_cast<size_t>(idx)) {
-fetchs_.resize(idx + 1);
+if (fetches_.size() <= static_cast<size_t>(idx)) {
+fetches_.resize(idx + 1);
}
-fetchs_[idx] = op;
+fetches_[idx] = op;
}
}
}
@@ -638,12 +640,12 @@ bool AnalysisPredictor::need_collect_var_shapes_for_memory_optim() {
// check if the cache exists
if (!config_.enable_memory_optim()) {
need = false;
-} else if (config_.enable_memory_optim() &&
+} else if (config_.static_memory_optim_ &&
!inference::IsFileExists(inference::analysis::GetMemoryCachePath(
config_.model_dir(), config_.prog_file()))) {
need = true;
-} else if (config_.enable_memory_optim() &&
-config_.memory_optim_force_update_) {
+} else if (config_.static_memory_optim_ &&
+config_.static_memory_optim_force_update_) {
need = true;
}
...
@@ -115,7 +115,7 @@ class AnalysisPredictor : public PaddlePredictor {
std::shared_ptr<framework::ProgramDesc> inference_program_;
std::vector<framework::OpDesc *> feeds_;
std::map<std::string, size_t> feed_names_;
-std::vector<framework::OpDesc *> fetchs_;
+std::vector<framework::OpDesc *> fetches_;
// Memory buffer for feed inputs. The temporary LoDTensor will cause serious
// concurrency problems, wrong results and memory leak, so cache them.
std::vector<framework::LoDTensor> feed_tensors_;
...
@@ -162,17 +162,7 @@ struct AnalysisConfig {
/** Transform the AnalysisConfig to NativeConfig.
*/
-NativeConfig ToNativeConfig() const {
-NativeConfig config;
-config.model_dir = model_dir_;
-config.prog_file = prog_file_;
-config.param_file = params_file_;
-config.use_gpu = use_gpu_;
-config.device = device_id_;
-config.fraction_of_gpu_memory = fraction_of_gpu_memory_for_pool();
-config.specify_input_name = specify_input_name_;
-return config;
-}
+NativeConfig ToNativeConfig() const;
/** Specify the operator type list to use MKLDNN acceleration.
* @param op_list the operator type list.
*/
@@ -195,7 +185,8 @@ struct AnalysisConfig {
/** Turn on memory optimize
* NOTE still in development, will release latter.
*/
-void EnableMemoryOptim(bool force_update_cache = false);
+void EnableMemoryOptim(bool static_optim = false,
+bool force_update_static_cache = false);
/** Tell whether the memory optimization is activated. */
bool enable_memory_optim() const;
@@ -241,7 +232,8 @@ struct AnalysisConfig {
// memory reuse related.
bool enable_memory_optim_{false};
-bool memory_optim_force_update_{false};
+bool static_memory_optim_{false};
+bool static_memory_optim_force_update_{false};
bool use_mkldnn_{false};
std::unordered_set<std::string> mkldnn_enabled_op_types_;
...
@@ -253,7 +253,7 @@ void compare(bool use_mkldnn = false) {
}
// Compare result of NativeConfig and AnalysisConfig with memory optimization.
-TEST(Analyzer_dam, compare_with_memory_optim) {
+TEST(Analyzer_dam, compare_with_static_memory_optim) {
// The small dam will core in CI, but works in local.
if (FLAGS_max_turn_num == 9) {
contrib::AnalysisConfig cfg, cfg1;
@@ -263,7 +263,7 @@ TEST(Analyzer_dam, compare_with_static_memory_optim) {
SetInput(&input_slots_all);
// Run the first time to force to update memory cache
SetConfig(&cfg);
-cfg.EnableMemoryOptim(true);
+cfg.EnableMemoryOptim(true, true /*force update*/);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
@@ -271,7 +271,7 @@ TEST(Analyzer_dam, compare_with_static_memory_optim) {
// Run second time to use the memory cache and perform memory optimization.
SetConfig(&cfg1);
-cfg1.EnableMemoryOptim();
+cfg1.EnableMemoryOptim(true, false /*do not force update*/);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg1),
@@ -279,6 +279,24 @@ TEST(Analyzer_dam, compare_with_static_memory_optim) {
}
}
+TEST(Analyzer_dam, compare_with_dynamic_memory_optim) {
+// The small dam will core in CI, but works in local.
+if (FLAGS_max_turn_num == 9) {
+contrib::AnalysisConfig cfg, cfg1;
+DataRecord data(FLAGS_infer_data, FLAGS_batch_size);
+std::vector<std::vector<PaddleTensor>> input_slots_all;
+SetInput(&input_slots_all);
+// Run the first time to force to update memory cache
+SetConfig(&cfg);
+cfg.EnableMemoryOptim();
+CompareNativeAndAnalysis(
+reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+input_slots_all);
+}
+}
TEST(Analyzer_dam, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
...
@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
// Async
if (!sync_mode_) {
VLOG(3) << "async process var: " << varname;
+if (varname == BATCH_BARRIER_MESSAGE) {
+PADDLE_THROW(
+"async mode should not recv BATCH_BARRIER_MESSAGE or "
+"COMPLETE_MESSAGE");
+}
try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope);
...
@@ -39,27 +39,33 @@ void RPCServer::SavePort() const {
port_file.open(file_path);
port_file << selected_port_;
port_file.close();
-VLOG(4) << "selected port written to " << file_path;
+VLOG(3) << "selected port written to " << file_path;
}
void RPCServer::WaitBarrier(const std::string& rpc_name) {
+VLOG(3) << "WaitBarrier in: " << rpc_name;
std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [this, &rpc_name] {
return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
exit_flag_.load());
});
-VLOG(3) << "batch_barrier_: " << rpc_name << " "
-<< barrier_counter_[rpc_name];
+VLOG(3) << "WaitBarrier out: " << rpc_name
+<< " counter: " << barrier_counter_[rpc_name];
}
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
-VLOG(4) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
+VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
+// barrier msg should make sure that it's in the right cond(send|recv)
+WaitCond(rpc_name);
int b = 0;
std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name];
+VLOG(3) << rpc_name << " barrier_counter: " << b;
if (b >= client_num_) {
lock.unlock();
+VLOG(3) << "BatchBarrier counter reach " << client_num_ << " for "
+<< rpc_name;
barrier_cond_.notify_all();
lock.lock();
}
@@ -71,7 +77,7 @@ void RPCServer::Complete() {
client_num_--;
need_reset_all_vars_ = true;
-VLOG(4) << "decrease client_num to: " << client_num_;
+VLOG(3) << "decrease client_num to: " << client_num_;
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
barrier_counter_[kRequestGet]--;
}
@@ -105,8 +111,8 @@ void RPCServer::RegisterRPC(const std::string& rpc_name,
static int cond = -1;
rpc_cond_map_[rpc_name] = ++cond;
-VLOG(4) << "RegisterRPC rpc_name:" << rpc_name << ", handler:" << handler
-<< ", cond:" << rpc_cond_map_[rpc_name];
+VLOG(3) << "RegisterRPC rpc_name: " << rpc_name << ", handler: " << handler
+<< ", cond: " << rpc_cond_map_[rpc_name];
}
void RPCServer::SetCond(const std::string& rpc_name) {
@@ -120,7 +126,7 @@ void RPCServer::SetCond(const std::string& rpc_name) {
}
void RPCServer::WaitCond(const std::string& rpc_name) {
-VLOG(4) << "RPCServer WaitCond " << rpc_name;
+VLOG(3) << "RPCServer WaitCond in " << rpc_name;
int cond = 0;
{
std::unique_lock<std::mutex> lock(mutex_);
@@ -130,6 +136,7 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
std::unique_lock<std::mutex> lock(mutex_);
rpc_cond_.wait(
lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
+VLOG(3) << "RPCServer WaitCond out " << rpc_name;
}
void RPCServer::RegisterVar(const std::string& var_name,
@@ -151,7 +158,7 @@ void RPCServer::RegisterVar(const std::string& var_name,
}
rpc_cond_.notify_all();
-VLOG(4) << "RegisterVar context:" << h.String();
+VLOG(3) << "RegisterVar context:" << h.String();
}
void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
@@ -167,11 +174,11 @@ void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
barrier_cond_.notify_all();
}
-VLOG(4) << "IncreaseVarBarrier context:" << h.String();
+VLOG(3) << "IncreaseVarBarrier context:" << h.String();
}
void RPCServer::WaitVarBarrier(const std::string& var_name) {
-VLOG(4) << "WaitBarrier var_name:" << var_name;
+VLOG(3) << "WaitVarBarrier var_name:" << var_name;
std::unique_lock<std::mutex> lock(mutex_);
barrier_cond_.wait(lock, [&]() {
@@ -179,11 +186,11 @@ void RPCServer::WaitVarBarrier(const std::string& var_name) {
exit_flag_.load());
});
-VLOG(4) << "WaitBarrier context: " << var_map_[var_name].String();
+VLOG(3) << "WaitVarBarrier context: " << var_map_[var_name].String();
}
void RPCServer::SetVarCond(const std::string& var_name) {
-VLOG(4) << "SetVarCond var_name:" << var_name;
+VLOG(3) << "SetVarCond var_name:" << var_name;
{
std::unique_lock<std::mutex> lock(mutex_);
if (var_map_.find(var_name) != var_map_.end()) {
@@ -193,14 +200,14 @@ void RPCServer::SetVarCond(const std::string& var_name) {
}
void RPCServer::WaitVarCond(const std::string& var_name) {
-VLOG(4) << "WaitVarCond var_name:" << var_name;
+VLOG(3) << "WaitVarCond var_name:" << var_name;
std::unique_lock<std::mutex> lock(mutex_);
rpc_cond_.wait(lock, [=] {
return (var_map_.find(var_name) != var_map_.end() || exit_flag_.load());
});
-VLOG(4) << "WaitVarCond var_name:" << var_name << " end";
+VLOG(3) << "WaitVarCond var_name:" << var_name << " end";
}
MonomerHandle RPCServer::GetMonomer(const std::string& var_name) {
...
@@ -137,7 +137,9 @@ void ListenAndServOp::RunSyncLoop(
while (true) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
+VLOG(3) << "wait all clients to send gradient";
rpc_service_->SetCond(distributed::kRequestSend);
+VLOG(3) << "wait all clients to send send_barrier";
rpc_service_->WaitBarrier(distributed::kRequestSend);
if (rpc_service_->IsExit()) {
@@ -168,12 +170,16 @@ void ListenAndServOp::RunSyncLoop(
}
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
recv_scope);
-VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
+VLOG(3) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
+VLOG(3) << "ResetReceivedVars";
ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
+VLOG(3) << "wait all clients to get parameters back";
rpc_service_->SetCond(distributed::kRequestGet);
+VLOG(3) << "wait all clients to send fetch_barrier";
rpc_service_->WaitBarrier(distributed::kRequestGet);
+VLOG(3) << "ResetBarrierCounter";
rpc_service_->ResetBarrierCounter();
} // while(true)
}
...
@@ -43,12 +43,14 @@ class GridSampleOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
"Input(X) and Input(Grid) dims[0] should be equal.");
+if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
grid_dims[1], x_dims[2],
"Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
PADDLE_ENFORCE_EQ(
grid_dims[2], x_dims[3],
"Input(X) dims[3] and Input(Grid) dims[2] should be equal.");
+}
ctx->SetOutputDim("Output", x_dims);
ctx->ShareLoD("X", "Output");
...
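A short note on why the two equality checks above were fenced behind ctx->IsRuntime() (my reading of the change, not text from the patch): during compile-time shape inference a dimension can still be the -1 placeholder, so comparing it against another dimension would fail spuriously; the shapes are only fully known at runtime. A framework-free sketch with invented names and shapes:

#include <cstdint>
#include <vector>

// Illustration only: at compile time a dim may still be -1,
// e.g. x_dims = {-1, 3, 32, 32}; only at runtime is it a real batch size.
bool GridDimChecksMeaningful(const std::vector<int64_t>& x_dims,
                             const std::vector<int64_t>& grid_dims) {
  // Mirrors the guarded checks above: only compare once no placeholder remains.
  if (x_dims[2] < 0 || x_dims[3] < 0 || grid_dims[1] < 0 || grid_dims[2] < 0) {
    return false;  // compile-time InferShape: skip the comparison, as the patch now does
  }
  return grid_dims[1] == x_dims[2] && grid_dims[2] == x_dims[3];
}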
if(WITH_NGRAPH)
+cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
cc_library(ngraph_engine SRCS ngraph_engine.cc DEPS ngraph_bridge framework_proto)
op_library(ngraph_engine_op DEPS ngraph_engine op_registry op_info device_context)
endif()
@@ -17,39 +17,39 @@ limitations under the License. */
#include <vector>
#include "ngraph/ngraph.hpp"
-#include "paddle/fluid/framework/ngraph_bridge.h"
-#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/operators/ngraph/ngraph_ops.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle {
-namespace framework {
+namespace operators {
namespace NG_OPS = paddle::operators::ngraphs;
std::map<std::string,
-std::function<void(const std::shared_ptr<OperatorBase>&,
+std::function<void(const std::shared_ptr<framework::OperatorBase>&,
std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>>
NgraphBridge::NG_NODE_MAP = {
{"elementwise_add", NG_OPS::BuildElementwiseAddNode},
{"elementwise_add_grad", NG_OPS::BuildElementwiseAddGradNode},
-{"fill_constant", paddle::operators::ngraphs::BuildFillConstantNode},
-{"mean", paddle::operators::ngraphs::BuildMeanNode},
-{"mean_grad", paddle::operators::ngraphs::BuildMeanGradNode},
-{"mul", paddle::operators::ngraphs::BuildMulNode},
-{"mul_grad", paddle::operators::ngraphs::BuildMulGradNode},
-{"softmax", paddle::operators::ngraphs::BuildSoftmaxNode},
-{"softmax_grad", paddle::operators::ngraphs::BuildSoftmaxGradNode},
-{"scale", paddle::operators::ngraphs::BuildScaleNode},
-{"relu", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Relu>},
-{"tanh", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Tanh>},
-{"top_k", paddle::operators::ngraphs::BuildTopKNode}};
+{"fill_constant", NG_OPS::BuildFillConstantNode},
+{"mean", NG_OPS::BuildMeanNode},
+{"mean_grad", NG_OPS::BuildMeanGradNode},
+{"mul", NG_OPS::BuildMulNode},
+{"mul_grad", NG_OPS::BuildMulGradNode},
+{"softmax", NG_OPS::BuildSoftmaxNode},
+{"softmax_grad", NG_OPS::BuildSoftmaxGradNode},
+{"scale", NG_OPS::BuildScaleNode},
+{"relu", NG_OPS::BuildUnaryNode<ngraph::op::Relu>},
+{"tanh", NG_OPS::BuildUnaryNode<ngraph::op::Tanh>},
+{"top_k", NG_OPS::BuildTopKNode}};
-void NgraphBridge::BuildNgNode(const std::shared_ptr<OperatorBase>& op) {
+void NgraphBridge::BuildNgNode(
+const std::shared_ptr<framework::OperatorBase>& op) {
auto& op_type = op->Type();
NG_NODE_MAP[op_type](op, ngb_node_map_);
}
-} // namespace framework
+} // namespace operators
} // namespace paddle
@@ -21,16 +21,16 @@ limitations under the License. */
#include "ngraph/node.hpp"
+#include "paddle/fluid/framework/operator.h"
namespace paddle {
-namespace framework {
-class OperatorBase;
+namespace operators {
class NgraphBridge {
public:
static std::map<
std::string,
-std::function<void(const std::shared_ptr<OperatorBase>&,
+std::function<void(const std::shared_ptr<framework::OperatorBase>&,
std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>>
NG_NODE_MAP;
@@ -41,7 +41,7 @@ class NgraphBridge {
var_node_map)
: ngb_node_map_(var_node_map) {}
-void BuildNgNode(const std::shared_ptr<OperatorBase>& op);
+void BuildNgNode(const std::shared_ptr<framework::OperatorBase>& op);
private:
std::shared_ptr<
@@ -49,5 +49,5 @@ class NgraphBridge {
ngb_node_map_;
};
-} // namespace framework
+} // namespace operators
} // namespace paddle
@@ -24,11 +24,11 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
-#include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type.h"
+#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/operators/ngraph/ngraph_engine.h"
namespace paddle {
@@ -88,15 +88,14 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
int pivot = left;
while (pivot < right) {
auto op_type = ops.at(pivot)->Type();
-if (paddle::framework::NgraphBridge::NG_NODE_MAP.find(op_type) ==
-paddle::framework::NgraphBridge::NG_NODE_MAP.end()) {
+if (NgraphBridge::NG_NODE_MAP.find(op_type) ==
+NgraphBridge::NG_NODE_MAP.end()) {
++pivot;
} else {
int start = pivot, end = start;
while (pivot < right &&
-(paddle::framework::NgraphBridge::NG_NODE_MAP.find(
-ops.at(pivot)->Type()) !=
-paddle::framework::NgraphBridge::NG_NODE_MAP.end())) {
+(NgraphBridge::NG_NODE_MAP.find(ops.at(pivot)->Type()) !=
+NgraphBridge::NG_NODE_MAP.end())) {
++pivot;
++end;
}
@@ -283,7 +282,7 @@ void NgraphEngine::BuildNgNodes() {
}
}
}
-framework::NgraphBridge ngb(var_node_map_);
+NgraphBridge ngb(var_node_map_);
for (auto& op : fused_ops_) {
ngb.BuildNgNode(op);
}
...
@@ -32,10 +32,13 @@ class Calibrator(object):
def __init__(self, *args, **kwargs):
self.program = kwargs['program']
-self.iterations = kwargs['iterations']
self.pretrained_model = kwargs['pretrained_model']
-self.debug = kwargs['debug']
+self.debug = kwargs['debug'] if 'debug' in kwargs else False
self.algo = kwargs['algo']
+self.output = kwargs['output']
+self.feed_var_names = kwargs['feed_var_names']
+self.fetch_list = kwargs['fetch_list']
+self.exe = kwargs['exe']
self._conv_input_var_name = []
self._conv_output_var_name = []
@@ -54,17 +57,38 @@ class Calibrator(object):
self._u8_output_var = []
self._s8_output_var = []
self._persistable_vars = []
+self._sampling_data = {}
def generate_sampling_program(self):
self.__init_analysis()
self.__generate_output_program()
-def generate_quantized_data(self, sampling_data):
-self.__sampling(sampling_data)
+def save_int8_model(self):
+self.__sampling(self._sampling_data)
self.__save_scale()
self.__update_program()
self.__update_output_program_attr()
self.__display_debug()
+self.__save_offline_model()
+def sample_data(self):
+'''
+Sampling the tensor data of variable.
+'''
+for i in self.sampling_program.list_vars():
+if i.name in self.sampling_vars:
+np_data = np.array(fluid.global_scope().find_var(i.name)
+.get_tensor())
+if i.name not in self._sampling_data:
+self._sampling_data[i.name] = []
+self._sampling_data[i.name].append(np_data)
+def __save_offline_model(self):
+'''
+Save the quantized model to the disk.
+'''
+fluid.io.save_inference_model(self.output, self.feed_var_names,
+self.fetch_list, self.exe,
+self.sampling_program)
def __display_debug(self):
if self.debug:
...
@@ -26,7 +26,7 @@ import paddle.fluid.profiler as profiler
from PIL import Image, ImageEnhance
import math
sys.path.append('..')
-import int8_inference.utility as ut
+import int8_inference.utility as int8_utility
random.seed(0)
np.random.seed(0)
@@ -120,13 +120,13 @@ class TestCalibration(unittest.TestCase):
def setUp(self):
# TODO(guomingz): Put the download process in the cmake.
# Download and unzip test data set
-imagenet_dl_url = 'http://paddle-inference-dist.bj.bcebos.com/int8/calibration_test_data.tar.gz'
+imagenet_dl_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/calibration_test_data.tar.gz'
zip_file_name = imagenet_dl_url.split('/')[-1]
cmd = 'rm -rf data {} && mkdir data && wget {} && tar xvf {} -C data'.format(
zip_file_name, imagenet_dl_url, zip_file_name)
os.system(cmd)
# resnet50 fp32 data
-resnet50_fp32_model_url = 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
+resnet50_fp32_model_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/resnet50_int8_model.tar.gz'
resnet50_zip_name = resnet50_fp32_model_url.split('/')[-1]
resnet50_unzip_folder_name = 'resnet50_fp32'
cmd = 'rm -rf {} {} && mkdir {} && wget {} && tar xvf {} -C {}'.format(
@@ -135,8 +135,7 @@ class TestCalibration(unittest.TestCase):
resnet50_zip_name, resnet50_unzip_folder_name)
os.system(cmd)
-self.iterations = 100
-self.skip_batch_num = 5
+self.iterations = 50
def run_program(self, model_path, generate_int8=False, algo='direct'):
image_shape = [3, 224, 224]
@@ -163,16 +162,15 @@ class TestCalibration(unittest.TestCase):
print("Start calibration ...")
-calibrator = ut.Calibrator(
+calibrator = int8_utility.Calibrator(
program=infer_program,
pretrained_model=model_path,
-iterations=100,
-debug=False,
-algo=algo)
-sampling_data = {}
+algo=algo,
+exe=exe,
+output=int8_model,
+feed_var_names=feed_dict,
+fetch_list=fetch_targets)
calibrator.generate_sampling_program()
test_info = []
cnt = 0
for batch_id, data in enumerate(val_reader()):
@@ -192,13 +190,7 @@ class TestCalibration(unittest.TestCase):
feed_dict[1]: label},
fetch_list=fetch_targets)
if generate_int8:
-for i in calibrator.sampling_program.list_vars():
-if i.name in calibrator.sampling_vars:
-np_data = np.array(fluid.global_scope().find_var(i.name)
-.get_tensor())
-if i.name not in sampling_data:
-sampling_data[i.name] = []
-sampling_data[i.name].append(np_data)
+calibrator.sample_data()
test_info.append(np.mean(acc1) * len(data))
cnt += len(data)
@@ -209,9 +201,8 @@ class TestCalibration(unittest.TestCase):
break
if generate_int8:
-calibrator.generate_quantized_data(sampling_data)
-fluid.io.save_inference_model(int8_model, feed_dict, fetch_targets,
-exe, calibrator.sampling_program)
+calibrator.save_int8_model()
print(
"Calibration is done and the corresponding files were generated at {}".
format(os.path.abspath("calibration_out")))
...