提交 42359e88 编写于 作者: T Tao Luo

clean code

test=develop
上级 923b1887
......@@ -304,7 +304,6 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
// NOTE All the members in AnalysisConfig should be copied to Argument.
void AnalysisPredictor::OptimizeInferenceProgram() {
LOG(INFO) << "optimization program";
status_program_optimized_ = true;
argument_.SetUseGPU(config_.use_gpu);
......@@ -313,11 +312,13 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
// Analyze inference_program
if (!config_.model_dir.empty()) {
argument_.SetModelDir(config_.model_dir);
} else if (!config_.param_file.empty() && !config_.prog_file.empty()) {
} else {
PADDLE_ENFORCE(
!config_.param_file.empty(),
"Either model_dir or (param_file, prog_file) should be set.");
PADDLE_ENFORCE(!config_.prog_file.empty());
argument_.SetModelProgramPath(config_.prog_file);
argument_.SetModelParamsPath(config_.param_file);
} else {
PADDLE_THROW("Either model_dir or (param_file, prog_file) should be set.");
}
if (config_.use_gpu && config_.use_tensorrt_) {
......
......@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/operators/impl/load_combine.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/pybind/pybind.h"
......
cc_library(load_combine_impl SRCS load_combine.cc DEPS scope lod_tensor device_context op_registry data_type_transform)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/impl/load_combine.h"
namespace paddle {
namespace operators {
namespace impl {
void LoadParamsFromStream(const std::vector<std::string> &out_var_names,
const paddle::platform::Place &place,
bool load_as_fp16, std::istream *buffer,
const paddle::framework::Scope *scope) {
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
for (size_t i = 0; i < out_var_names.size(); i++) {
auto *out_var = scope->FindVar(out_var_names[i]);
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
out_var_names[i]);
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
// Get data from fin to tensor
DeserializeFromStream(*buffer, tensor, *dev_ctx);
auto in_dtype = framework::ToDataType(tensor->type());
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
// convert to float16 tensor
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor fp16_tensor;
// copy LoD info to the new tensor
fp16_tensor.set_lod(tensor->lod());
framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
&fp16_tensor);
// reset output tensor
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
}
} // namespace impl
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
namespace impl {
// Load parameters from a single stream.
void LoadParamsFromStream(const std::vector<std::string> &out_var_names,
const platform::Place &place, bool load_as_fp16,
std::istream *buffer, const framework::Scope *scope);
} // namespace impl
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册