提交 0dd2e7f1 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #356 from codeWorm2015/develop

fix #355 remove unused codes
cmake_minimum_required(VERSION 3.0)
project(paddle-mobile)
add_definitions(-DPADDLE_MOBILE_DEBUG)
#add_definitions(-DPADDLE_MOBILE_DEBUG)
add_definitions(-DENABLE_EXCEPTION)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "log.h"
namespace paddle_mobile {}
......@@ -2638,147 +2638,3 @@ protobuf_c_boolean protobuf_c_message_check(const ProtobufCMessage *message) {
typedef void (*GenericHandler)(void *service, const ProtobufCMessage *input,
ProtobufCClosure closure, void *closure_data);
void protobuf_c_service_invoke_internal(ProtobufCService *service,
unsigned method_index,
const ProtobufCMessage *input,
ProtobufCClosure closure,
void *closure_data) {
GenericHandler *handlers;
GenericHandler handler;
/*
* Verify that method_index is within range. If this fails, you are
* likely invoking a newly added method on an old service. (Although
* other memory corruption bugs can cause this assertion too.)
*/
assert(method_index < service->descriptor->n_methods);
/*
* Get the array of virtual methods (which are enumerated by the
* generated code).
*/
handlers = (GenericHandler *)(service + 1);
/*
* Get our method and invoke it.
* \todo Seems like handler == NULL is a situation that needs handling.
*/
handler = handlers[method_index];
(*handler)(service, input, closure, closure_data);
}
void protobuf_c_service_generated_init(
ProtobufCService *service, const ProtobufCServiceDescriptor *descriptor,
ProtobufCServiceDestroy destroy) {
ASSERT_IS_SERVICE_DESCRIPTOR(descriptor);
service->descriptor = descriptor;
service->destroy = destroy;
service->invoke = protobuf_c_service_invoke_internal;
memset(service + 1, 0, descriptor->n_methods * sizeof(GenericHandler));
}
void protobuf_c_service_destroy(ProtobufCService *service) {
service->destroy(service);
}
/* --- querying the descriptors --- */
const ProtobufCEnumValue *protobuf_c_enum_descriptor_get_value_by_name(
const ProtobufCEnumDescriptor *desc, const char *name) {
unsigned start = 0;
unsigned count;
if (desc == NULL || desc->values_by_name == NULL) return NULL;
count = desc->n_value_names;
while (count > 1) {
unsigned mid = start + count / 2;
int rv = strcmp(desc->values_by_name[mid].name, name);
if (rv == 0)
return desc->values + desc->values_by_name[mid].index;
else if (rv < 0) {
count = start + count - (mid + 1);
start = mid + 1;
} else
count = mid - start;
}
if (count == 0) return NULL;
if (strcmp(desc->values_by_name[start].name, name) == 0)
return desc->values + desc->values_by_name[start].index;
return NULL;
}
const ProtobufCEnumValue *protobuf_c_enum_descriptor_get_value(
const ProtobufCEnumDescriptor *desc, int value) {
int rv = int_range_lookup(desc->n_value_ranges, desc->value_ranges, value);
if (rv < 0) return NULL;
return desc->values + rv;
}
const ProtobufCFieldDescriptor *protobuf_c_message_descriptor_get_field_by_name(
const ProtobufCMessageDescriptor *desc, const char *name) {
unsigned start = 0;
unsigned count;
const ProtobufCFieldDescriptor *field;
if (desc == NULL || desc->fields_sorted_by_name == NULL) return NULL;
count = desc->n_fields;
while (count > 1) {
unsigned mid = start + count / 2;
int rv;
field = desc->fields + desc->fields_sorted_by_name[mid];
rv = strcmp(field->name, name);
if (rv == 0)
return field;
else if (rv < 0) {
count = start + count - (mid + 1);
start = mid + 1;
} else
count = mid - start;
}
if (count == 0) return NULL;
field = desc->fields + desc->fields_sorted_by_name[start];
if (strcmp(field->name, name) == 0) return field;
return NULL;
}
const ProtobufCFieldDescriptor *protobuf_c_message_descriptor_get_field(
const ProtobufCMessageDescriptor *desc, unsigned value) {
int rv = int_range_lookup(desc->n_field_ranges, desc->field_ranges, value);
if (rv < 0) return NULL;
return desc->fields + rv;
}
const ProtobufCMethodDescriptor *
protobuf_c_service_descriptor_get_method_by_name(
const ProtobufCServiceDescriptor *desc, const char *name) {
unsigned start = 0;
unsigned count;
if (desc == NULL || desc->method_indices_by_name == NULL) return NULL;
count = desc->n_methods;
while (count > 1) {
unsigned mid = start + count / 2;
unsigned mid_index = desc->method_indices_by_name[mid];
const char *mid_name = desc->methods[mid_index].name;
int rv = strcmp(mid_name, name);
if (rv == 0) return desc->methods + desc->method_indices_by_name[mid];
if (rv < 0) {
count = start + count - (mid + 1);
start = mid + 1;
} else {
count = mid - start;
}
}
if (count == 0) return NULL;
if (strcmp(desc->methods[desc->method_indices_by_name[start]].name, name) ==
0)
return desc->methods + desc->method_indices_by_name[start];
return NULL;
}
......@@ -798,76 +798,6 @@ uint32_t protobuf_c_version_number(void);
*/
#define PROTOBUF_C_MIN_COMPILER_VERSION 1000000
/**
* Look up a `ProtobufCEnumValue` from a `ProtobufCEnumDescriptor` by name.
*
* \param desc
* The `ProtobufCEnumDescriptor` object.
* \param name
* The `name` field from the corresponding `ProtobufCEnumValue` object to
* match.
* \return
* A `ProtobufCEnumValue` object.
* \retval NULL
* If not found or if the optimize_for = CODE_SIZE option was set.
*/
PROTOBUF_C__API
const ProtobufCEnumValue *protobuf_c_enum_descriptor_get_value_by_name(
const ProtobufCEnumDescriptor *desc, const char *name);
/**
* Look up a `ProtobufCEnumValue` from a `ProtobufCEnumDescriptor` by numeric
* value.
*
* \param desc
* The `ProtobufCEnumDescriptor` object.
* \param value
* The `value` field from the corresponding `ProtobufCEnumValue` object to
* match.
*
* \return
* A `ProtobufCEnumValue` object.
* \retval NULL
* If not found.
*/
PROTOBUF_C__API
const ProtobufCEnumValue *protobuf_c_enum_descriptor_get_value(
const ProtobufCEnumDescriptor *desc, int value);
/**
* Look up a `ProtobufCFieldDescriptor` from a `ProtobufCMessageDescriptor` by
* the name of the field.
*
* \param desc
* The `ProtobufCMessageDescriptor` object.
* \param name
* The name of the field.
* \return
* A `ProtobufCFieldDescriptor` object.
* \retval NULL
* If not found or if the optimize_for = CODE_SIZE option was set.
*/
PROTOBUF_C__API
const ProtobufCFieldDescriptor *protobuf_c_message_descriptor_get_field_by_name(
const ProtobufCMessageDescriptor *desc, const char *name);
/**
* Look up a `ProtobufCFieldDescriptor` from a `ProtobufCMessageDescriptor` by
* the tag value of the field.
*
* \param desc
* The `ProtobufCMessageDescriptor` object.
* \param value
* The tag value of the field.
* \return
* A `ProtobufCFieldDescriptor` object.
* \retval NULL
* If not found.
*/
PROTOBUF_C__API
const ProtobufCFieldDescriptor *protobuf_c_message_descriptor_get_field(
const ProtobufCMessageDescriptor *desc, unsigned value);
/**
* Determine the number of bytes required to store the serialised message.
*
......@@ -947,33 +877,6 @@ PROTOBUF_C__API
void protobuf_c_message_init(const ProtobufCMessageDescriptor *descriptor,
void *message);
/**
* Free a service.
*
* \param service
* The service object to free.
*/
PROTOBUF_C__API
void protobuf_c_service_destroy(ProtobufCService *service);
/**
* Look up a `ProtobufCMethodDescriptor` by name.
*
* \param desc
* Service descriptor.
* \param name
* Name of the method.
*
* \return
* A `ProtobufCMethodDescriptor` object.
* \retval NULL
* If not found or if the optimize_for = CODE_SIZE option was set.
*/
PROTOBUF_C__API
const ProtobufCMethodDescriptor *
protobuf_c_service_descriptor_get_method_by_name(
const ProtobufCServiceDescriptor *desc, const char *name);
/**
* Initialise a `ProtobufCBufferSimple` object.
*/
......@@ -1011,18 +914,6 @@ PROTOBUF_C__API
void protobuf_c_buffer_simple_append(ProtobufCBuffer *buffer, size_t len,
const unsigned char *data);
PROTOBUF_C__API
void protobuf_c_service_generated_init(
ProtobufCService *service, const ProtobufCServiceDescriptor *descriptor,
ProtobufCServiceDestroy destroy);
PROTOBUF_C__API
void protobuf_c_service_invoke_internal(ProtobufCService *service,
unsigned method_index,
const ProtobufCMessage *input,
ProtobufCClosure closure,
void *closure_data);
/**@}*/
PROTOBUF_C__END_DECLS
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
......@@ -19,7 +19,6 @@ limitations under the License. */
#include <vector>
#include "framework/op_kernel_type.h"
#include "framework/selected_rows.h"
#include "framework/tensor.h"
#include "framework/variable.h"
......
......@@ -27,7 +27,6 @@ limitations under the License. */
#include "framework/op_info.h"
#include "framework/op_kernel_type.h"
#include "framework/op_registry.h"
#include "framework/paddle_mobile_object.h"
#include "framework/program/block_desc.h"
#include "framework/program/program-optimize/node.h"
#include "framework/scope.h"
......@@ -52,7 +51,7 @@ static T *GetVarValue(const string &key, const VariableNameMap &var_map,
}
template <typename Dtype>
class OperatorBase : PaddleMobileObject {
class OperatorBase {
public:
/*
* @b op 基类的实例化方法, op 获取到了 输入、参数以及提前分配好的输出 tensor
......@@ -121,7 +120,7 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
* @b 所有kernel的父类
* */
template <typename Dtype, typename P>
class OpKernelBase : PaddleMobileObject {
class OpKernelBase {
public:
/*
* @b 所有kernel 需实现 Compute 方法
......@@ -139,7 +138,7 @@ class OpKernelBase : PaddleMobileObject {
std::shared_ptr<::paddle_mobile::framework::Scope> scope) \
: parent_cls<Dtype, T>(type, inputs, outputs, attrs, scope) {}
class FusionOpMatcher : PaddleMobileObject {
class FusionOpMatcher {
public:
FusionOpMatcher() {}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle_mobile_object.h"
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "stdio.h"
namespace paddle_mobile {
class PaddleMobileObject {
public:
virtual std::string ToString() {
char address[128] = {0};
sprintf(address, "%p", this);
return std::string(address);
}
private:
};
} // namespace paddle_mobile
......@@ -15,14 +15,13 @@ limitations under the License. */
#pragma once
#include "framework/framework.pb-c.h"
#include "framework/paddle_mobile_object.h"
#include "framework/program/op_desc.h"
#include "framework/program/var_desc.h"
namespace paddle_mobile {
namespace framework {
class BlockDesc : PaddleMobileObject {
class BlockDesc {
public:
friend class Node;
friend class ProgramOptimize;
......
......@@ -20,12 +20,11 @@ limitations under the License. */
#include "common/log.h"
#include "common/type_define.h"
#include "framework/framework.pb-c.h"
#include "framework/paddle_mobile_object.h"
namespace paddle_mobile {
namespace framework {
class OpDesc : PaddleMobileObject {
class OpDesc {
public:
friend class ProgramOptimize;
friend class FusionOpMatcher;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "fusion_op_register.h"
......@@ -21,13 +21,12 @@ limitations under the License. */
#include <vector>
#include "common/log.h"
#include "framework/paddle_mobile_object.h"
#include "framework/program/op_desc.h"
namespace paddle_mobile {
namespace framework {
class Node : PaddleMobileObject {
class Node {
friend class ProgramOptimize;
public:
......
......@@ -105,11 +105,7 @@ std::shared_ptr<ProgramDesc> ProgramOptimize::FushionOptimize(
}
}
// DLOG << "node: \n" << *begin_node;
std::vector<std::shared_ptr<framework::OpDesc>> op_descs;
// bool can_splite = begin_node->CanSplit({G_OP_TYPE_CONV,
// G_OP_TYPE_BATCHNORM, G_OP_TYPE_DEPTHWISE_CONV});
for (int m = 0; m < nodes.size(); ++m) {
auto &node = nodes[m];
op_descs.push_back(node->op_desc_);
......
......@@ -15,7 +15,6 @@ limitations under the License. */
#pragma once
#include "common/types.h"
#include "framework/paddle_mobile_object.h"
#include "framework/program/program_desc.h"
#include "framework/scope.h"
......@@ -23,7 +22,7 @@ namespace paddle_mobile {
namespace framework {
template <typename Dtype, Precision P = Precision::FP32>
class Program : PaddleMobileObject {
class Program {
public:
std::shared_ptr<ProgramDesc> originProgram;
std::shared_ptr<ProgramDesc> optimizeProgram;
......
......@@ -18,13 +18,12 @@ limitations under the License. */
#include "common/types.h"
#include "framework/framework.pb-c.h"
#include "framework/paddle_mobile_object.h"
#include "framework/program/block_desc.h"
namespace paddle_mobile {
namespace framework {
class ProgramDesc : PaddleMobileObject {
class ProgramDesc {
public:
friend class Node;
friend class ProgramOptimize;
......
......@@ -14,40 +14,14 @@ limitations under the License. */
#pragma once
#include <string>
#include "framework/framework.pb-c.h"
#include "framework/paddle_mobile_object.h"
#include "framework/program/tensor_desc.h"
namespace paddle_mobile {
namespace framework {
/*
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__BOOL = 0,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT16 = 1,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT32 = 2,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__INT64 = 3,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP16 = 4,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP32 = 5,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FP64 = 6,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__LOD_TENSOR = 7,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__SELECTED_ROWS = 8,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FEED_MINIBATCH = 9,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__FETCH_LIST = 10,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__STEP_SCOPES = 11,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__LOD_RANK_TABLE = 12,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__LOD_TENSOR_ARRAY = 13,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__PLACE_LIST = 14,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__READER = 15,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__CHANNEL = 16,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__RAW = 17,
PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__TUPLE = 18
*/
class VarDesc {
public:
VarDesc(const VarDesc &var_desc) {
......@@ -56,14 +30,6 @@ class VarDesc {
this->persistable_ = var_desc.persistable_;
this->tensor_desc_ = var_desc.tensor_desc_;
this->type_ = var_desc.type_;
/*
*
* std::string name_;
bool persistable_;
TensorDesc tensor_desc_;
VarType_Type type_;
VarType_Type data_type_;
* */
}
VarDesc(PaddleMobile__Framework__Proto__VarDesc *desc) {
type_ = (VarType_Type)desc->type->type;
......@@ -102,39 +68,6 @@ class VarDesc {
const TensorDesc &Tensor_desc() const { return tensor_desc_; }
// const proto::VarType::ChannelDesc &channel_desc() const {
// switch (desc_.type().type()) {
// case proto::VarType::CHANNEL:
// return desc_.type().channel();
// default:
// break;
// }
// }
// proto::VarType::Type GetDataType() const {
// switch (desc_.type().type()) {
// case proto::VarType::CHANNEL:
// return channel_desc().data_type();
// break;
// default:
// return tensor_desc().data_type();
// }
// }
// template <typename T>
// std::vector<T> RepeatedToVector(
// const google::protobuf::RepeatedField<T> &repeated_field) const {
// std::vector<T> ret;
// ret.reserve(repeated_field.size());
// std::copy(repeated_field.begin(), repeated_field.end(),
// std::back_inserter(ret));
// return ret;
// }
// std::vector<int64_t> GetShape() const {
// return this->RepeatedToVector(tensor_desc().dims());
// }
private:
std::string name_;
bool persistable_;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lod_tensor.h"
#include "tensor.h"
namespace paddle_mobile {
namespace framework {
class SelectedRows {
public:
SelectedRows(const std::vector<int64_t> &rows, const int64_t &height)
: rows_(rows), height_(height) {
value_.reset(new Tensor());
}
SelectedRows() {
height_ = 0;
value_.reset(new Tensor());
}
const Tensor &value() const { return *value_; }
Tensor *mutable_value() { return value_.get(); }
int64_t height() const { return height_; }
void set_height(int64_t height) { height_ = height; }
const std::vector<int64_t> &rows() const { return rows_; }
std::vector<int64_t> *mutable_rows() { return &rows_; }
void set_rows(const std::vector<int64_t> &rows) { rows_ = rows; }
/**
* get the index of id in rows
*/
int64_t index(int64_t id) const {
auto it = std::find(rows_.begin(), rows_.end(), id);
// PADDLE_ENFORCE(it != rows_.end(), "id should be in rows");
return static_cast<int64_t>(std::distance(rows_.begin(), it));
}
DDim GetCompleteDims() const {
std::vector<int64_t> dims = vectorize(value_->dims());
dims[0] = height_;
return make_ddim(dims);
}
private:
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9}
// here.
// SelectedRows are simply concated when adding together. Until a
// SelectedRows add a Tensor, will the duplicate rows be handled.
std::vector<int64_t> rows_;
std::unique_ptr<Tensor> value_{nullptr};
int64_t height_;
};
} // namespace framework
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "framework.pb.h"
#include "lod_tensor.h"
#include "selected_rows.h"
#include "variable.h"
namespace paddle_mobile {
namespace framework {
inline proto::VarType::Type ToVarType(std::type_index type) {
if (type.hash_code() == typeid(LoDTensor).hash_code()) {
return proto::VarType_Type_LOD_TENSOR;
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
return proto::VarType_Type_SELECTED_ROWS;
} else {
// PADDLE_THROW("ToVarType:Unsupported type %s",
// type.name());
}
}
} // namespace framework
} // namespace paddle_mobile
......@@ -20,13 +20,12 @@ limitations under the License. */
#include <typeindex>
#include <typeinfo>
#include "../common/variant.h"
#include "paddle_mobile_object.h"
namespace paddle_mobile {
namespace framework {
using std::string;
class Variable : public PaddleMobileObject {
class Variable {
public:
template <typename T>
const T *Get() const {
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "io.h"
#include <fstream>
#include <vector>
#include "common/log.h"
......@@ -30,16 +29,20 @@ limitations under the License. */
namespace paddle_mobile {
using framework::Variable;
void ReadBinaryFile(const std::string &filename, std::string *contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_MOBILE_ENFORCE(fin.is_open(), "open file: %s failed",
char *Get_binary_data(std::string filename) {
FILE *file = fopen(filename.c_str(), "rb");
PADDLE_MOBILE_ENFORCE(file != nullptr, "can't open file: %s ",
filename.c_str());
fin.seekg(0, std::ios::end);
contents->clear();
contents->resize(fin.tellg());
fin.seekg(0, std::ios::beg);
fin.read(&(contents->at(0)), contents->size());
fin.close();
fseek(file, 0, SEEK_END);
long size = ftell(file);
PADDLE_MOBILE_ENFORCE(size > 0, "size is too small");
rewind(file);
char *data = new char[size];
size_t bytes_read = fread(data, 1, size, file);
PADDLE_MOBILE_ENFORCE(bytes_read == size,
"read binary file bytes do not match with fseek");
fclose(file);
return data;
}
static size_t ReadBuffer(const char *file_name, uint8_t **out) {
......@@ -70,64 +73,47 @@ void Loader<Dtype, P>::LoadVar(framework::Variable *variable,
const framework::VarDesc &var_desc,
const std::string &file_path) {
auto tensor = variable->GetMutable<framework::LoDTensor>();
std::ifstream is(file_path);
PADDLE_MOBILE_ENFORCE(is.is_open(), "open file: %s failed",
file_path.c_str());
std::fpos<mbstate_t> pos;
pos = is.tellg(); // save current position
is.seekg(0, std::ios::end);
is.seekg(pos); // restore saved position
char *data = Get_binary_data(file_path);
// 1. version
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
uint32_t version = *(uint32_t *)data;
data += sizeof(uint32_t);
// 2 Lod information
uint64_t lod_level;
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
uint32_t lod_level = *(uint64_t *)data;
data += sizeof(uint64_t);
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
uint32_t size = *(uint64_t *)data;
data += sizeof(uint64_t);
std::vector<size_t> tmp(size / sizeof(size_t));
is.read(reinterpret_cast<char *>(tmp.data()),
static_cast<std::streamsize>(size));
for (auto j : tmp) {
LOG(kLOG_DEBUG1) << " lod - " << j;
for (int k = 0; k < tmp.size(); ++k) {
tmp[k] = *(size_t *)data;
}
lod[i] = tmp;
}
// 3. tensor version
uint32_t tensor_version;
is.read(reinterpret_cast<char *>(&tensor_version), sizeof(tensor_version));
uint32_t tensor_version = *(uint32_t *)data;
data += sizeof(uint32_t);
// 4. tensor desc
int32_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
uint32_t size = *(int32_t *)data;
data += sizeof(int32_t);
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char *>(buf.get()), size);
for (int m = 0; m < size; ++m) {
buf.get()[m] = data[m];
}
const framework::TensorDesc &desc = var_desc.Tensor_desc();
PaddleMobile__Framework__Proto__VarType__TensorDesc *tensor_desc = NULL;
// void *v;
// PaddleMobile__Framework__Proto__VarType__TensorDesc_Closure()(tensor_desc,
// buf.get());
// DLOG << "PaddleMobile__Framework__Proto__VarType__TensorDesc_Closure- " <<
// tensor_desc;
// framework::TensorDesc &tensor_desc = variable->
// PaddleMobile__Framework__Proto__ProgramDesc *c_program;
// uint8_t *proto_buf = NULL;
// size_t read_size = ReadBuffer(file_path.c_str(), &proto_buf);
// c_program = paddle_mobile__framework__proto__program_desc__unpack(NULL,
// read_size, buf);
// paddle_mobile__framework__proto__var_type__tensor_desc__init()
int memory_size = 1;
for (auto l : desc.Dims()) {
......@@ -162,8 +148,11 @@ void Loader<Dtype, P>::LoadVar(framework::Variable *variable,
break;
}
is.read(static_cast<char *>(memory), memory_size * type_size);
is.close();
for (int n = 0; n < memory_size * type_size; ++n) {
static_cast<char *>(memory)[n] = data[n];
}
delete data;
}
template <typename Dtype, Precision P>
......@@ -276,29 +265,34 @@ template <typename Dtype, Precision P>
void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc,
framework::LoDTensor *tensor,
const std::string &file_path) {
std::ifstream is(file_path);
PADDLE_MOBILE_ENFORCE(is.is_open(), "open file: %s failed",
file_path.c_str());
std::fpos<mbstate_t> pos;
pos = is.tellg(); // save current position
is.seekg(0, std::ios::end);
is.seekg(pos); // restore saved position
char *origin_data = Get_binary_data(file_path);
char *data = origin_data;
// 1. version
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
uint32_t version = *(uint32_t *)data;
data += sizeof(uint32_t);
DLOG << "version: " << version;
// 2 Lod information
uint64_t lod_level;
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
uint64_t lod_level = *(uint64_t *)data;
data += sizeof(uint64_t);
DLOG << "lod_level: " << lod_level;
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
uint64_t size = *(uint64_t *)data;
data += sizeof(uint64_t);
DLOG << "lod size: " << i << size;
std::vector<size_t> tmp(size / sizeof(size_t));
is.read(reinterpret_cast<char *>(tmp.data()),
static_cast<std::streamsize>(size));
for (int k = 0; k < tmp.size(); ++k) {
tmp[k] = *(size_t *)data;
DLOG << "tmp[k]: " << k << *(size_t *)data;
data += sizeof(size_t);
}
for (auto j : tmp) {
LOG(kLOG_DEBUG1) << " lod - " << j;
}
......@@ -306,17 +300,22 @@ void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc,
}
// 3. tensor version
uint32_t tensor_version;
is.read(reinterpret_cast<char *>(&tensor_version), sizeof(tensor_version));
uint32_t tensor_version = *(uint32_t *)data;
data += sizeof(uint32_t);
DLOG << "tensor_version: " << tensor_version;
// 4. tensor desc
int32_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
int32_t size = *(int32_t *)data;
data += sizeof(int32_t);
DLOG << "tensor desc size: " << size;
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char *>(buf.get()), size);
for (int m = 0; m < size; ++m) {
buf.get()[m] = data[m];
}
data += (sizeof(char) * size);
const framework::TensorDesc &desc = var_desc.Tensor_desc();
int memory_size = 1;
for (auto l : desc.Dims()) {
memory_size *= l;
......@@ -332,6 +331,7 @@ void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc,
break;
case framework::VARTYPE_TYPE_FP32:
type_size = 4;
DLOG << " type size: " << type_size;
memory = tensor->mutable_data<float>();
break;
case framework::VARTYPE_TYPE_FP64:
......@@ -350,8 +350,11 @@ void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc,
break;
}
is.read(static_cast<char *>(memory), memory_size * type_size);
is.close();
for (int n = 0; n < memory_size * type_size; ++n) {
static_cast<char *>(memory)[n] = data[n];
}
delete origin_data;
}
template <typename Dtype, Precision P>
......
......@@ -22,14 +22,13 @@ limitations under the License. */
#include "common/types.h"
#include "framework/lod_tensor.h"
#include "framework/operator.h"
#include "framework/paddle_mobile_object.h"
#include "framework/program/program.h"
#include "framework/tensor.h"
namespace paddle_mobile {
template <typename Dtype, Precision P = Precision::FP32>
class Loader : PaddleMobileObject {
class Loader {
public:
const framework::Program<Dtype, P> Load(const std::string &dirname,
bool optimize = false);
......
......@@ -34,7 +34,7 @@ using framework::Tensor;
using std::string;
using std::vector;
class OpParam : PaddleMobileObject {
class OpParam {
protected:
template <typename T>
static T *InputFrom(const VariableNameMap &inputs, const Scope &scope) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册