未验证 提交 9df2d8b5 编写于 作者: Y Yan Chunwei 提交者: GitHub

test/add text-classification test (#13081)

上级 4907d093
......@@ -86,15 +86,24 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
}
op_desc.SetInput("Bias", {new_bias_var});
}
#undef GET_NODE
// Create temp variables.
scope->Var(name_scope + "/BatchedInput.new")
->GetMutable<framework::LoDTensor>();
scope->Var(name_scope + "/BatchCellPreAct.new")
->GetMutable<framework::LoDTensor>();
scope->Var(name_scope + "/BatchedGate.new")
->GetMutable<framework::LoDTensor>();
op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {});
op_desc.SetOutput("Hidden", {hidden_n->Name()});
op_desc.SetOutput("Cell", {cell_n->Name()});
op_desc.SetOutput("XX", {xx_n->Name()});
op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
op_desc.SetOutput("BatchedGate", {name_scope + "/BatchedGate.new"});
op_desc.SetOutput("BatchCellPreAct", {name_scope + "/BatchCellPreAct.new"});
op_desc.SetOutput("BatchedInput", {name_scope + "/BatchedInput.new"});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr
......@@ -130,8 +139,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
int fusion_count{0};
auto fc_no_bias_handler = [&](
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
#define GET_NODE(name__) \
std::string name__##key = name_scope + "/" + #name__; \
auto* name__##n = pattern->RetrieveNode(name__##key); \
......@@ -152,21 +161,24 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
if (with_fc_bias) {
GET_NODE(fc_bias);
GET_NODE(elementwise_add);
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, fc_bias);
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes(
{mul_n, lstm_n, elementwise_add_n});
GraphSafeRemoveNodes(graph, marked_nodes);
} else {
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, -1);
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes({mul_n, lstm_n});
GraphSafeRemoveNodes(graph, marked_nodes);
}
#undef GET_NODE
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes({mul_n, lstm_n});
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, fc_no_bias_handler);
gpd(graph, handler);
return fusion_count;
}
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
......
......@@ -73,7 +73,6 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
void GraphPatternDetector::operator()(Graph* graph,
GraphPatternDetector::handle_t handler) {
if (!MarkPDNodesInGraph(*graph)) {
LOG(INFO) << "Mark failed";
return;
}
......
......@@ -19,6 +19,9 @@
#endif
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/inference/analysis/dot.h"
......
......@@ -58,7 +58,7 @@ endif()
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor
ARGS --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
--infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
--infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
......@@ -74,7 +74,7 @@ inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc)
set(CHINESE_NER_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/chinese_ner_model.tar.gz")
set(CHINESE_NER_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/chinese_ner-data.txt.tar.gz")
set(CHINESE_NER_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/chinese_ner" CACHE PATH "Chinese ner model and data root." FORCE)
if (NOT EXISTS ${CHINESE_NER_INSTALL_DIR} AND WITH_TESTING)
if (NOT EXISTS ${CHINESE_NER_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE)
inference_download_and_uncompress(${CHINESE_NER_INSTALL_DIR} ${CHINESE_NER_MODEL_URL} "chinese_ner_model.tar.gz")
inference_download_and_uncompress(${CHINESE_NER_INSTALL_DIR} ${CHINESE_NER_DATA_URL} "chinese_ner-data.txt.tar.gz")
endif()
......@@ -87,7 +87,7 @@ inference_analysis_test(test_analyzer_ner SRCS analyzer_ner_tester.cc
set(LAC_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/lac_model.tar.gz")
set(LAC_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/lac_data.txt.tar.gz")
set(LAC_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/lac" CACHE PATH "LAC model and data root." FORCE)
if (NOT EXISTS ${LAC_INSTALL_DIR} AND WITH_TESTING)
if (NOT EXISTS ${LAC_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE)
inference_download_and_uncompress(${LAC_INSTALL_DIR} ${LAC_MODEL_URL} "lac_model.tar.gz")
inference_download_and_uncompress(${LAC_INSTALL_DIR} ${LAC_DATA_URL} "lac_data.txt.tar.gz")
endif()
......@@ -96,3 +96,15 @@ inference_analysis_test(test_analyzer_lac SRCS analyzer_lac_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api
ARGS --infer_model=${LAC_INSTALL_DIR}/model
--infer_data=${LAC_INSTALL_DIR}/data.txt)
set(TEXT_CLASSIFICATION_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/text-classification-Senta.tar.gz")
set(TEXT_CLASSIFICATION_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/text_classification" CACHE PATH "Text Classification model and data root." FORCE)
if (NOT EXISTS ${TEXT_CLASSIFICATION_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE)
inference_download_and_uncompress(${TEXT_CLASSIFICATION_INSTALL_DIR} ${TEXT_CLASSIFICATION_MODEL_URL} "text-classification-Senta.tar.gz")
endif()
inference_analysis_test(test_text_classification SRCS test_text_classification.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api analysis_predictor
ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/text-classification-Senta)
......@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/analyzer.h"
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
......@@ -41,20 +42,16 @@ class DfgPassManagerImpl final : public DfgPassManager {
public:
DfgPassManagerImpl() {
// TODO(Superjomn) set the key with pass reprs.
LOG(INFO)
<< "-----------------------------------------------------------------";
if (FLAGS_IA_enable_ir) {
AddPass("fluid-to-ir-pass", new FluidToIrPass);
} else {
if (!FLAGS_IA_enable_ir) {
AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
} else {
AddPass("fluid-to-ir-pass", new FluidToIrPass);
}
TryAddTensorRtPass();
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
if (!FLAGS_IA_output_storage_path.empty()) {
AddPass("model-store-pass", new ModelStorePass);
}
LOG(INFO)
<< "-----------------------------------------------------------------";
}
std::string repr() const override { return "dfg-pass-manager"; }
......@@ -101,19 +98,16 @@ class DfgPassManagerImpl final : public DfgPassManager {
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) {
std::vector<std::string> passes;
for (auto& pass : all_ir_passes_) {
if (!disabled_ir_passes_.count(pass)) {
passes.push_back(pass);
passes.push_back("graph_viz_pass"); // add graphviz for debug.
}
}
passes.push_back("graph_viz_pass");
// Ugly support fluid-to-ir-pass
argument->Set(kFluidToIrPassesAttr,
new std::vector<std::string>({
// Manual update the passes here.
"graph_viz_pass", //
"infer_clean_graph_pass", "graph_viz_pass", //
"attention_lstm_fuse_pass", "graph_viz_pass", //
"fc_lstm_fuse_pass", "graph_viz_pass", //
"mul_lstm_fuse_pass", "graph_viz_pass", //
"seq_concat_fc_fuse_pass", "graph_viz_pass", //
"fc_fuse_pass", "graph_viz_pass" //
}));
argument->Set(kFluidToIrPassesAttr, new std::vector<std::string>(passes));
for (auto& x : data_) {
PADDLE_ENFORCE(x->Initialize(argument));
......@@ -122,6 +116,11 @@ void Analyzer::Run(Argument* argument) {
}
}
Analyzer& Analyzer::DisableIrPasses(const std::vector<std::string>& passes) {
disabled_ir_passes_.insert(passes.begin(), passes.end());
return *this;
}
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -36,16 +36,10 @@ limitations under the License. */
*/
#include <gflags/gflags.h>
#include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
// flag if not available.
DECLARE_bool(IA_enable_tensorrt_subgraph_engine);
DECLARE_string(IA_graphviz_log_root);
DECLARE_string(IA_output_storage_path);
DECLARE_bool(IA_enable_ir);
namespace paddle {
namespace inference {
namespace analysis {
......@@ -57,7 +51,26 @@ class Analyzer : public OrderedRegistry<PassManager> {
void Run(Argument* argument);
Analyzer& DisableIrPasses(const std::vector<std::string>& passes);
DISABLE_COPY_AND_ASSIGN(Analyzer);
private:
// All avaiable IR passes.
// The bigger fuse comes first, so that the small operators prefer to be
// merged in a larger fuse op. The small fusion will not break the pattern of
// larger fusion.
const std::vector<std::string> all_ir_passes_{{
// Manual update the passes here.
"infer_clean_graph_pass", //
"attention_lstm_fuse_pass", //
"fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"fc_fuse_pass", //
}};
std::unordered_set<std::string> disabled_ir_passes_;
};
} // namespace analysis
......
......@@ -271,17 +271,22 @@ void TestDituRNNPrediction(const std::string &model_path,
const std::string &data_path, int batch_size,
bool use_analysis, bool activate_ir,
int num_times = 1) {
NativeConfig config;
AnalysisConfig config;
config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
config.use_gpu = false;
config.device = 0;
config.specify_input_name = true;
config.enable_ir_optim = activate_ir;
PADDLE_ENFORCE(config.ir_mode ==
AnalysisConfig::IrPassMode::kExclude); // default
config.ir_passes.clear(); // Do not exclude any pass.
auto base_predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
auto predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(config);
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
config);
std::vector<PaddleTensor> input_slots;
DataRecord data(data_path, batch_size);
// Prepare inputs.
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
// flag if not available.
DECLARE_bool(IA_enable_tensorrt_subgraph_engine);
DECLARE_string(IA_graphviz_log_root);
DECLARE_string(IA_output_storage_path);
DECLARE_bool(IA_enable_ir);
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/pass.h"
......@@ -85,9 +86,11 @@ class FluidToIrPass final : public DataFlowGraphPass {
new Scope *(&argument_->Get<Scope>(ir::kParamScopeAttr)));
}
const auto &ir_passes_to_apply =
argument_->Get<std::vector<std::string>>(kFluidToIrPassesAttr);
ir_passes.Apply(ir_passes_to_apply);
if (FLAGS_IA_enable_ir) {
const auto &ir_passes_to_apply =
argument_->Get<std::vector<std::string>>(kFluidToIrPassesAttr);
ir_passes.Apply(ir_passes_to_apply);
}
PADDLE_ENFORCE(argument_->main_dfg.get());
argument_->main_dfg->Build(ir_passes.graph());
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <glog/logging.h> // use glog instead of PADDLE_ENFORCE to avoid importing other paddle header files.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/timer.h"
DEFINE_string(infer_model, "", "Directory of the inference model.");
DEFINE_string(infer_data, "", "Path of the dataset.");
DEFINE_int32(batch_size, 1, "batch size.");
DEFINE_int32(repeat, 1, "How many times to repeat run.");
namespace paddle {
template <typename T>
std::string to_string(const std::vector<T> &vec) {
std::stringstream ss;
for (const auto &c : vec) {
ss << c << " ";
}
return ss.str();
}
void PrintTime(const double latency, const int bs, const int repeat) {
LOG(INFO) << "===========profile result===========";
LOG(INFO) << "batch_size: " << bs << ", repeat: " << repeat
<< ", avg latency: " << latency / repeat << "ms";
LOG(INFO) << "=====================================";
}
void Main(int batch_size) {
// Three sequence inputs.
std::vector<PaddleTensor> input_slots(1);
// one batch starts
// data --
int64_t data0[] = {0, 1, 2};
for (auto &input : input_slots) {
input.data.Reset(data0, sizeof(data0));
input.shape = std::vector<int>({3, 1});
// dtype --
input.dtype = PaddleDType::INT64;
// LoD --
input.lod = std::vector<std::vector<size_t>>({{0, 3}});
}
// shape --
// Create Predictor --
AnalysisConfig config;
config.model_dir = FLAGS_infer_model;
config.use_gpu = false;
config.enable_ir_optim = true;
config.ir_passes.push_back("fc_lstm_fuse_pass");
auto predictor =
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
config);
inference::Timer timer;
double sum = 0;
std::vector<PaddleTensor> output_slots;
for (int i = 0; i < FLAGS_repeat; i++) {
timer.tic();
CHECK(predictor->Run(input_slots, &output_slots));
sum += timer.toc();
}
PrintTime(sum, batch_size, FLAGS_repeat);
// Get output
LOG(INFO) << "get outputs " << output_slots.size();
for (auto &output : output_slots) {
LOG(INFO) << "output.shape: " << to_string(output.shape);
// no lod ?
CHECK_EQ(output.lod.size(), 0UL);
LOG(INFO) << "output.dtype: " << output.dtype;
std::stringstream ss;
for (int i = 0; i < 5; i++) {
ss << static_cast<float *>(output.data.data())[i] << " ";
}
LOG(INFO) << "output.data summary: " << ss.str();
// one batch ends
}
}
TEST(text_classification, basic) { Main(FLAGS_batch_size); }
} // namespace paddle
USE_PASS(fc_fuse_pass);
USE_PASS(seq_concat_fc_fuse_pass);
USE_PASS(fc_lstm_fuse_pass);
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
USE_PASS(attention_lstm_fuse_pass);
......@@ -44,7 +44,19 @@ function(inference_api_test TARGET_NAME)
endfunction(inference_api_test)
cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor)
cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api)
cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api
analysis
ir_pass_manager
pass
fc_fuse_pass
fc_lstm_fuse_pass
seq_concat_fc_fuse_pass
graph_viz_pass
infer_clean_graph_pass
graph_pattern_detector
infer_clean_graph_pass
attention_lstm_fuse_pass
)
cc_test(test_paddle_inference_api
SRCS api_tester.cc
......
......@@ -14,6 +14,8 @@
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
......@@ -28,6 +30,8 @@ bool AnalysisPredictor::Init(
VLOG(3) << "Predictor::init()";
if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device);
LOG(WARNING) << "ir optimize only supports CPU currently";
config_.enable_ir_optim = false;
} else {
place_ = paddle::platform::CPUPlace();
}
......@@ -73,7 +77,7 @@ bool AnalysisPredictor::Init(
void AnalysisPredictor::OptimizeInferenceProgram() {
LOG(INFO) << "optimize begin";
FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_ir = config_.enable_ir_optim;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = ""; // Don't output the model.
// Analyze inference_program
......@@ -90,24 +94,26 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
}
argument_.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto()));
Analyzer().Run(&argument_);
PADDLE_ENFORCE(config_.ir_mode == AnalysisConfig::IrPassMode::kExclude,
"Only kExclude is supported yet.");
Analyzer().DisableIrPasses(config_.ir_passes).Run(&argument_);
CHECK(argument_.transformed_program_desc);
VLOG(5) << "to prepare executor";
// LOG(INFO) << "transformed_parogram_desc " <<
// argument.transformed_program_desc->DebugString();
inference_program_.reset(
new framework::ProgramDesc(*argument_.transformed_program_desc));
PADDLE_ENFORCE(argument_.Has(framework::ir::kParamScopeAttr));
// Update scope.
scope_.reset(
argument_.Release<framework::Scope>(framework::ir::kParamScopeAttr));
LOG(INFO) << "optimize end ==";
if (argument_.Has(framework::ir::kParamScopeAttr)) {
// Update scope.
scope_.reset(
argument_.Release<framework::Scope>(framework::ir::kParamScopeAttr));
}
LOG(INFO) << "== optimize end ==";
}
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
NativeConfig, PaddleEngineKind::kAnalysis>(const NativeConfig& config) {
VLOG(3) << "create NativePredictor";
AnalysisConfig, PaddleEngineKind::kAnalysis>(const AnalysisConfig& config) {
VLOG(3) << "create AnalysisConfig";
if (config.use_gpu) {
// 1. GPU memeroy
PADDLE_ENFORCE_GT(
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
......@@ -28,7 +30,7 @@ using framework::proto::ProgramDesc;
*/
class AnalysisPredictor : public NativePaddlePredictor {
public:
explicit AnalysisPredictor(const NativeConfig& config)
explicit AnalysisPredictor(const AnalysisConfig& config)
: NativePaddlePredictor(config), config_(config) {}
bool Init(const std::shared_ptr<framework::Scope>& parent_scope);
......@@ -44,7 +46,7 @@ class AnalysisPredictor : public NativePaddlePredictor {
Argument& analysis_argument() { return argument_; }
private:
NativeConfig config_;
AnalysisConfig config_;
Argument argument_;
};
......
......@@ -176,7 +176,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
framework::Scope *scope) {
VLOG(3) << "Predictor::set_feed";
if (inputs.size() != feeds_.size()) {
LOG(ERROR) << "wrong feed input size.";
LOG(ERROR) << "wrong feed input size, need " << feeds_.size() << " but get "
<< inputs.size();
return false;
}
for (size_t i = 0; i < inputs.size(); ++i) {
......
......@@ -150,6 +150,21 @@ struct TensorRTConfig : public NativeConfig {
int workspace_size{1 << 30};
};
// NOTE WIP, not stable yet.
struct AnalysisConfig : public NativeConfig {
//
enum class IrPassMode {
kSystem, // Use system default passes, not customize.
kInclude, // Specify the passes in `ir_passes`.
kExclude // Specify the disabled passes in `ir_passes`.
};
bool enable_ir_optim = true;
IrPassMode ir_mode{IrPassMode::kExclude};
// attention lstm fuse works only on some specific models, disable as default.
std::vector<std::string> ir_passes{"attention_lstm_fuse_pass"};
};
// A factory to help create different predictors.
//
// FOR EXTENSION DEVELOPER:
......
......@@ -57,7 +57,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0);
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册