How should a 1-D input vector be reshaped for a C++ inference model?
Created by: wuxianshen
When I deploy the model in C++ with a 1-D vector as input, it fails at run time.
Main error message:

```
---------------------- Error Message Summary:
InvalidArgumentError: Broadcast dimension mismatch. Operands could not be broadcast together with the shape of X = [80, 1, 1, 400] and the shape of Y = [400]. Received [80] in X is not equal to [400] in Y at i:0. [Hint: Expected x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || y_dims_array[i] <= 1 == true, but received x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || y_dims_array[i] <= 1:0 != true:1.] at (/home/tao/work/deep_learning/paddle-cpp/paddle/paddle/fluid/operators/elementwise/elementwise_op_function.h:157)
```

After changing `input_t->Reshape({batch_size, channels, height, width});` to `input_t->Reshape({width});`, the program runs, but the batch setting seems to have no effect: only one set of inputs is computed.
Environment:

- Inference library build info: GIT COMMIT ID: 12f64401, WITH_MKL: ON, WITH_MKLDNN: ON, WITH_GPU: ON, CUDA version: 10.1, CUDNN version: v7, WITH_TENSORRT: ON
- System: Ubuntu 16.04, x86_64
- Python version: 3.6.1
- CUDA version: 10.1.243
- cuDNN version: 7.6.5
- NVIDIA driver version: 418.87.00
- TensorRT version: 6.0.1.5 (TensorRT-6.0.1.5.Ubuntu-16.04.x86_64-gnu.cuda-10.1.cudnn7.6.tar.gz)
Full error log:

```
Induced field model inference...
WARNING: Logging before InitGoogleLogging() is written to STDERR
I0627 11:34:22.342473 18407 analysis_predictor.cc:138] Profiler is deactivated, and no profiling report will be generated.
I0627 11:34:22.346983 18407 analysis_predictor.cc:872] MODEL VERSION: 1.8.1
I0627 11:34:22.346990 18407 analysis_predictor.cc:874] PREDICTOR VERSION: 0.0.0
W0627 11:34:22.347020 18407 analysis_predictor.cc:887] - Version incompatible (1) elementwise_add
W0627 11:34:22.347043 18407 analysis_predictor.cc:887] - Version incompatible (1) feed
W0627 11:34:22.347048 18407 analysis_predictor.cc:887] - Version incompatible (1) fetch
W0627 11:34:22.347050 18407 analysis_predictor.cc:887] - Version incompatible (2) matmul
W0627 11:34:22.347054 18407 analysis_predictor.cc:887] - Version incompatible (1) scale
W0627 11:34:22.347056 18407 analysis_predictor.cc:194] WARNING: Results may be DIFF! Please use the corresponding version of the model and prediction library, and do not use the develop branch.
I0627 11:34:22.347095 18407 analysis_predictor.cc:432] TensorRT subgraph engine is enabled
--- Running analysis [ir_graph_build_pass]
--- Running analysis [ir_graph_clean_pass]
--- Running analysis [ir_analysis_pass]
--- Running IR pass [conv_affine_channel_fuse_pass]
--- Running IR pass [conv_eltwiseadd_affine_channel_fuse_pass]
--- Running IR pass [shuffle_channel_detect_pass]
--- Running IR pass [quant_conv2d_dequant_fuse_pass]
--- Running IR pass [delete_quant_dequant_op_pass]
--- Running IR pass [simplify_with_basic_ops_pass]
--- Running IR pass [embedding_eltwise_layernorm_fuse_pass]
--- Running IR pass [multihead_matmul_fuse_pass_v2]
--- Running IR pass [skip_layernorm_fuse_pass]
--- Running IR pass [conv_bn_fuse_pass]
--- Running IR pass [fc_fuse_pass]
--- Running IR pass [tensorrt_subgraph_pass]
--- Running IR pass [conv_bn_fuse_pass]
--- Running IR pass [conv_elementwise_add_act_fuse_pass]
--- Running IR pass [conv_elementwise_add2_act_fuse_pass]
--- Running IR pass [conv_elementwise_add_fuse_pass]
--- Running IR pass [transpose_flatten_concat_fuse_pass]
--- Running analysis [ir_params_sync_among_devices_pass]
I0627 11:34:22.352695 18407 ir_params_sync_among_devices_pass.cc:41] Sync params from CPU to GPU
--- Running analysis [adjust_cudnn_workspace_size_pass]
--- Running analysis [inference_op_replace_pass]
--- Running analysis [ir_graph_to_program_pass]
I0627 11:34:22.355870 18407 analysis_predictor.cc:493] ======= optimize end =======
I0627 11:34:22.355886 18407 naive_executor.cc:95] --- skip [feed], feed -> feed_0
I0627 11:34:22.355924 18407 naive_executor.cc:95] --- skip [save_infer_model/scale_0.tmp_0], fetch -> fetch
W0627 11:34:22.355954 18407 device_context.cc:252] Please NOTE: device: 0, CUDA Capability: 75, Driver API Version: 10.1, Runtime API Version: 10.1
W0627 11:34:22.358325 18407 device_context.cc:260] device: 0, cuDNN Version: 7.6.
terminate called after throwing an instance of 'paddle::platform::EnforceNotMet'
  what():
C++ Call Stacks (More useful to developers):
0   std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > paddle::platform::GetTraceBackString<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&&, char const*, int)
1   paddle::platform::EnforceNotMet::EnforceNotMet(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, char const*, int)
2   paddle::operators::GetBroadcastDimsArrays(paddle::framework::DDim const&, paddle::framework::DDim const&, int*, int*, int*, int, int)
3   paddle::operators::ElementwiseOp::InferShape(paddle::framework::InferShapeContext*) const
4   paddle::framework::OperatorWithKernel::RunImpl(paddle::framework::Scope const&, paddle::platform::Place const&, paddle::framework::RuntimeContext*) const
5   paddle::framework::OperatorWithKernel::RunImpl(paddle::framework::Scope const&, paddle::platform::Place const&) const
6   paddle::framework::OperatorBase::Run(paddle::framework::Scope const&, paddle::platform::Place const&)
7   paddle::framework::NaiveExecutor::Run()
8   paddle::AnalysisPredictor::ZeroCopyRun()
Error Message Summary:
InvalidArgumentError: Broadcast dimension mismatch. Operands could not be broadcast together with the shape of X = [80, 1, 1, 400] and the shape of Y = [400]. Received [80] in X is not equal to [400] in Y at i:0. [Hint: Expected x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || y_dims_array[i] <= 1 == true, but received x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || y_dims_array[i] <= 1:0 != true:1.] at (/home/tao/work/deep_learning/paddle-cpp/paddle/paddle/fluid/operators/elementwise/elementwise_op_function.h:157)
Aborted (core dumped)
```
C++ inference code:

```cpp
#include <gflags/gflags.h>
#include <glog/logging.h>

#include <chrono>
#include <cstring>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

#include "paddle/include/paddle_inference_api.h"

namespace paddle {
using paddle::AnalysisConfig;

DEFINE_string(dirname, "../infer_model", "Directory of the inference model.");

using Time = decltype(std::chrono::high_resolution_clock::now());

Time time() { return std::chrono::high_resolution_clock::now(); }

// Elapsed time from t1 to t2, in milliseconds.
double time_diff(Time t1, Time t2) {
  typedef std::chrono::microseconds us;
  us counter = std::chrono::duration_cast<us>(t2 - t1);
  return counter.count() / 1000.0;
}

void prepare_trt_config(AnalysisConfig* config, int batch_size) {
  // 1. Model directory
  // case 1: model_dir contains two files: model + params
  // config->SetModel(FLAGS_dirname + "/model", FLAGS_dirname + "/params");
  // case 2: model_dir contains one model file (__model__) and multiple params
  config->SetModel(FLAGS_dirname);
  // 2. Device
  // config->DisableGpu();
  config->EnableUseGpu(5000, 0);
  // 3. General optimization
  // Graph optimization, including operator fusion
  // config->SwitchIrOptim(true);
  // Memory reuse
  // config->EnableMemoryOptim();
  // We use ZeroCopyTensor here, so we set config->SwitchUseFeedFetchOps(false)
  config->SwitchUseFeedFetchOps(false);
  config->EnableTensorRtEngine(1 << 20, batch_size, 3,
                               AnalysisConfig::Precision::kFloat32, false);
}

void test_induced_field(int batch_size, int repeat) {
  AnalysisConfig config;
  prepare_trt_config(&config, batch_size);
  auto predictor = CreatePaddlePredictor(config);

  int channels = 1;
  int height = 1;
  int width = 8;  // 8 floats as input vector
  int input_num = channels * height * width * batch_size;

  // Prepare inputs: the same 8-float sample repeated batch_size times.
  float input[] = {9.38945293e+00, 9.54218483e+00, 4.00937128e+00, 2.98900008e-01,
                   7.20000000e+02, 1.02042408e+01, 9.72727203e+00, 4.30936718e+00};
  CHECK(sizeof(input) == width * sizeof(float));
  std::vector<float> raw_data(input_num);  // freed automatically, unlike new[]
  for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
    memcpy(raw_data.data() + batch_idx * width, input, sizeof(input));
  }

  auto input_names = predictor->GetInputNames();
  auto input_t = predictor->GetInputTensor(input_names[0]);
  input_t->Reshape({batch_size, channels, height, width});
  // input_t->Reshape({width});
  input_t->copy_from_cpu(raw_data.data());

  // Run `repeat` times so that the averaged latency below is meaningful.
  auto time1 = time();
  for (int i = 0; i < repeat; ++i) {
    CHECK(predictor->ZeroCopyRun());
  }
  auto time2 = time();
  std::cout << "batch: " << batch_size
            << " predict cost: " << time_diff(time1, time2) / static_cast<float>(repeat)
            << " ms" << std::endl;

  // Get the output.
  std::vector<float> out_data;
  auto output_names = predictor->GetOutputNames();
  auto output_t = predictor->GetOutputTensor(output_names[0]);
  std::vector<int> output_shape = output_t->shape();
  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                std::multiplies<int>());
  out_data.resize(out_num);
  output_t->copy_to_cpu(out_data.data());
  for (int j = 0; j < out_num; ++j) {
    LOG(INFO) << "output[ " << j << " ]: " << out_data[j];
  }
}
}  // namespace paddle

int main(int argc, char** argv) {
  // Parse --dirname so the model directory can be set on the command line.
  google::ParseCommandLineFlags(&argc, &argv, true);
  std::cout << "Induced field model inference..." << std::endl;
  paddle::test_induced_field(1, 1);
  return 0;
}
```
Network definition:

```python
import paddle.fluid as fluid


class FCNet(fluid.dygraph.Layer):
    def __init__(self, input_channels=8, output_channels=3):
        super(FCNet, self).__init__()
        fc1_size = input_channels * 10  # 80 with the defaults
        fc2_size = fc1_size * 5         # 400 with the defaults
        self.fc1 = fluid.dygraph.Linear(
            input_dim=input_channels,
            output_dim=fc1_size,
            dtype='float32',
            act=None
        )
        self.fc2 = fluid.dygraph.Linear(
            input_dim=fc1_size,
            output_dim=fc2_size,
            dtype='float32',
            act=None
        )
        self.fc3 = fluid.dygraph.Linear(
            input_dim=fc2_size,
            output_dim=output_channels,
            dtype='float32',
            act=None
        )

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
```
Converting the dygraph model to a static graph model:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import TracedLayer


def convert_to_static_graph(reader, use_cuda):
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    with fluid.dygraph.guard(place):
        model = FCNet(input_channels=8, output_channels=3)
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=0.0005, parameter_list=model.parameters())
        para_state_dict, opti_state_dict = fluid.load_dygraph('model/best_model')
        model.set_dict(para_state_dict)
        optimizer.set_dict(opti_state_dict)
        # One sample from the reader; the traced program is built from this tensor's shape.
        input = np.array(next(reader())[0][0]).astype('float32')
        print(input)
        input = fluid.dygraph.to_variable(input)
        out_dygraph, static_layer = TracedLayer.trace(model, inputs=[input])
        static_layer.save_inference_model('model/infer', feed=[0], fetch=[0])
```