Commit de94d06e authored by HexToString

fix again

Parent bf9b60f2
@@ -602,13 +602,13 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<FluidFamilyCore> {
     }
     //set inputHandle
-    BatchTensor* batchTensor_pointer_in = reinterpret_cast<const BatchTensor*>(in);
+    const BatchTensor* batchTensor_pointer_in = reinterpret_cast<const BatchTensor*>(in);
     for(int i =0; i< batchTensor_pointer_in->count();++i){
       Tensor tensor_in_batchTensor = (*batchTensor_pointer_in)[i];
-      auto lod_tensor_in = core.GetInputHandle(tensor_in_batchTensor.name);
+      auto lod_tensor_in = core->GetInputHandle(tensor_in_batchTensor.name);
       lod_tensor_in->SetLoD(tensor_in_batchTensor.lod);
       lod_tensor_in->Reshape(tensor_in_batchTensor.shape);
-      void* origin_data = tensor_in_batchTensor.data().data();
+      void* origin_data = tensor_in_batchTensor.data.data();
       if(tensor_in_batchTensor.type == FLOAT32){
         float* data = reinterpret_cast<float*>(origin_data);
         lod_tensor_in->CopyFromCpu(data);
@@ -627,9 +627,9 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<FluidFamilyCore> {
     //get out and copy to void* out
     BatchTensor* batchTensor_pointer_out = reinterpret_cast<BatchTensor*>(out);
-    std::vector<std::string> outnames = core.GetOutputNames();
+    std::vector<std::string> outnames = core->GetOutputNames();
     for (int i = 0; i < outnames.size(); ++i){
-      auto lod_tensor_out = core.GetOutputHandle(outnames[i]);
+      auto lod_tensor_out = core->GetOutputHandle(outnames[i]);
       std::vector<int> output_shape = lod_tensor_out->shape();
       int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
       int dataType = lod_tensor_out->type();
@@ -653,7 +653,7 @@ class FluidInferEngine : public CloneDBReloadableInferEngine<FluidFamilyCore> {
       }
       Tensor tensor_out;
       tensor_out.name = outnames[i];
-      tensor_out.type = dataType;
+      tensor_out.type = DataType(dataType);
       tensor_out.shape.assign(output_shape.begin(), output_shape.end());
       std::vector<std::vector<size_t>> out_lod = lod_tensor_out->lod();
       for (int li = 0; li < out_lod.size(); ++li) {
...
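Every change in this commit is mechanical: `core` is a pointer to the underlying Paddle predictor, so member access switches from `core.` to `core->`; the result of `reinterpret_cast<const BatchTensor*>(in)` is now stored in a matching `const BatchTensor*`; `data` is read as a member rather than called as a method; and the raw `int` type code is cast to the `DataType` enum before assignment. For reference, below is a minimal standalone sketch of the same feed/fetch pattern against the Paddle Inference 2.x `paddle_infer::Predictor` API that these calls (`GetInputHandle`, `Reshape`, `CopyFromCpu`, `GetOutputNames`, `GetOutputHandle`, `CopyToCpu`) come from. The model file names, input shape, and include path are assumptions for illustration, not taken from this repository.

```cpp
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "paddle_inference_api.h"  // Paddle Inference 2.x, namespace paddle_infer

int main() {
  // Hypothetical model files -- substitute a real exported inference model.
  paddle_infer::Config config("model.pdmodel", "model.pdiparams");
  std::shared_ptr<paddle_infer::Predictor> core =
      paddle_infer::CreatePredictor(config);

  // Feed: input handles are fetched from the predictor by name and used
  // through operator->, matching the core->GetInputHandle(...) calls above.
  std::vector<float> input(1 * 3, 0.5f);  // assumed [1, 3] float input
  auto lod_tensor_in = core->GetInputHandle(core->GetInputNames()[0]);
  lod_tensor_in->Reshape({1, 3});
  lod_tensor_in->CopyFromCpu(input.data());

  core->Run();

  // Fetch: the same pointer-style access for outputs, sized from shape().
  std::vector<std::string> outnames = core->GetOutputNames();
  auto lod_tensor_out = core->GetOutputHandle(outnames[0]);
  std::vector<int> output_shape = lod_tensor_out->shape();
  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                std::multiplies<int>());
  std::vector<float> out_data(out_num);
  lod_tensor_out->CopyToCpu(out_data.data());
  return 0;
}
```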