提交 1a44b2fb 编写于 作者: D dzhwinter

Merge remote-tracking branch 'origin/develop' into feature/ir_inplace_pass

...@@ -56,13 +56,6 @@ DECLARE_int32(paddle_num_threads); ...@@ -56,13 +56,6 @@ DECLARE_int32(paddle_num_threads);
namespace paddle { namespace paddle {
namespace inference { namespace inference {
float Random(float low, float high) {
static std::random_device rd;
static std::mt19937 mt(rd());
std::uniform_real_distribution<double> dist(low, high);
return dist(mt);
}
void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) { void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
const auto *analysis_config = const auto *analysis_config =
reinterpret_cast<const AnalysisConfig *>(config); reinterpret_cast<const AnalysisConfig *>(config);
...@@ -146,7 +139,8 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs, ...@@ -146,7 +139,8 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
const std::string &dirname, bool is_combined = true, const std::string &dirname, bool is_combined = true,
std::string model_filename = "model", std::string model_filename = "model",
std::string params_filename = "params", std::string params_filename = "params",
const std::vector<std::string> *feed_names = nullptr) { const std::vector<std::string> *feed_names = nullptr,
const int continuous_inuput_index = 0) {
// Set fake_image_data // Set fake_image_data
PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data."); PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data.");
std::vector<std::vector<int64_t>> feed_target_shapes = GetFeedTargetShapes( std::vector<std::vector<int64_t>> feed_target_shapes = GetFeedTargetShapes(
...@@ -183,7 +177,8 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs, ...@@ -183,7 +177,8 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
float *input_data = static_cast<float *>(input.data.data()); float *input_data = static_cast<float *>(input.data.data());
// fill input data, for profile easily, do not use random data here. // fill input data, for profile easily, do not use random data here.
for (size_t j = 0; j < len; ++j) { for (size_t j = 0; j < len; ++j) {
*(input_data + j) = Random(0.0, 1.0) / 10.; *(input_data + j) =
static_cast<float>((j + continuous_inuput_index) % len) / len;
} }
} }
(*inputs).emplace_back(input_slots); (*inputs).emplace_back(input_slots);
......
...@@ -119,9 +119,10 @@ void compare_continuous_input(std::string model_dir, bool use_tensorrt) { ...@@ -119,9 +119,10 @@ void compare_continuous_input(std::string model_dir, bool use_tensorrt) {
std::vector<std::vector<PaddleTensor>> inputs_all; std::vector<std::vector<PaddleTensor>> inputs_all;
if (!FLAGS_prog_filename.empty() && !FLAGS_param_filename.empty()) { if (!FLAGS_prog_filename.empty() && !FLAGS_param_filename.empty()) {
SetFakeImageInput(&inputs_all, model_dir, true, FLAGS_prog_filename, SetFakeImageInput(&inputs_all, model_dir, true, FLAGS_prog_filename,
FLAGS_param_filename); FLAGS_param_filename, nullptr, i);
} else { } else {
SetFakeImageInput(&inputs_all, model_dir, false, "__model__", ""); SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "", nullptr,
i);
} }
CompareNativeAndAnalysis(native_pred.get(), analysis_pred.get(), CompareNativeAndAnalysis(native_pred.get(), analysis_pred.get(),
inputs_all); inputs_all);
......
...@@ -136,7 +136,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { ...@@ -136,7 +136,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace()); sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
auto sum_mat = EigenMatrix<T>::From(sum); auto sum_mat = EigenMatrix<T>::From(sum);
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
auto out_mat = framework::EigenVector<T>::Flatten(*out); auto out_mat = framework::EigenMatrix<T>::From(*out);
if (bias) { if (bias) {
bit_code->Add(*bias, pre_out); bit_code->Add(*bias, pre_out);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册