提交 b8831606 编写于 作者: L liuruilong

format files

上级 30d22c55
此差异已折叠。
......@@ -202,34 +202,34 @@ size_t foo__bar__baz_bah__pack_to_buffer
#include <stdint.h>
#ifdef __cplusplus
# define PROTOBUF_C__BEGIN_DECLS extern "C" {
# define PROTOBUF_C__END_DECLS }
#define PROTOBUF_C__BEGIN_DECLS extern "C" {
#define PROTOBUF_C__END_DECLS }
#else
# define PROTOBUF_C__BEGIN_DECLS
# define PROTOBUF_C__END_DECLS
#define PROTOBUF_C__BEGIN_DECLS
#define PROTOBUF_C__END_DECLS
#endif
PROTOBUF_C__BEGIN_DECLS
#if defined(_WIN32) && defined(PROTOBUF_C_USE_SHARED_LIB)
# ifdef PROTOBUF_C_EXPORT
# define PROTOBUF_C__API __declspec(dllexport)
# else
# define PROTOBUF_C__API __declspec(dllimport)
# endif
#ifdef PROTOBUF_C_EXPORT
#define PROTOBUF_C__API __declspec(dllexport)
#else
# define PROTOBUF_C__API
#define PROTOBUF_C__API __declspec(dllimport)
#endif
#else
#define PROTOBUF_C__API
#endif
#if !defined(PROTOBUF_C__NO_DEPRECATED) && \
((__GNUC__ > 3) || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1))
# define PROTOBUF_C__DEPRECATED __attribute__((__deprecated__))
#define PROTOBUF_C__DEPRECATED __attribute__((__deprecated__))
#else
# define PROTOBUF_C__DEPRECATED
#define PROTOBUF_C__DEPRECATED
#endif
#ifndef PROTOBUF_C__FORCE_ENUM_TO_BE_INT_SIZE
#define PROTOBUF_C__FORCE_ENUM_TO_BE_INT_SIZE(enum_name) \
#define PROTOBUF_C__FORCE_ENUM_TO_BE_INT_SIZE(enum_name) \
, _##enum_name##_IS_INT_SIZE = INT_MAX
#endif
......@@ -441,9 +441,7 @@ protobuf_c_message_pack_to_buffer(&message, &tmp);
*/
struct ProtobufCBuffer {
/** Append function. Consumes the `len` bytes stored at `data`. */
void (*append)(ProtobufCBuffer *buffer,
size_t len,
const uint8_t *data);
void (*append)(ProtobufCBuffer *buffer, size_t len, const uint8_t *data);
};
/**
......@@ -733,10 +731,8 @@ struct ProtobufCService {
/** Service descriptor. */
const ProtobufCServiceDescriptor *descriptor;
/** Function to invoke the service. */
void (*invoke)(ProtobufCService *service,
unsigned method_index,
const ProtobufCMessage *input,
ProtobufCClosure closure,
void (*invoke)(ProtobufCService *service, unsigned method_index,
const ProtobufCMessage *input, ProtobufCClosure closure,
void *closure_data);
/** Function to destroy the service. */
void (*destroy)(ProtobufCService *service);
......@@ -772,8 +768,7 @@ struct ProtobufCServiceDescriptor {
* \return A string containing the version number of protobuf-c.
*/
PROTOBUF_C__API
const char *
protobuf_c_version(void);
const char *protobuf_c_version(void);
/**
* Get the version of the protobuf-c library. Note that this is the version of
......@@ -783,8 +778,7 @@ protobuf_c_version(void);
* protobuf-c, represented in base-10 as (MAJOR*1E6) + (MINOR*1E3) + PATCH.
*/
PROTOBUF_C__API
uint32_t
protobuf_c_version_number(void);
uint32_t protobuf_c_version_number(void);
/**
* The version of the protobuf-c headers, represented as a string using the same
......@@ -818,10 +812,8 @@ protobuf_c_version_number(void);
* If not found or if the optimize_for = CODE_SIZE option was set.
*/
PROTOBUF_C__API
const ProtobufCEnumValue *
protobuf_c_enum_descriptor_get_value_by_name(
const ProtobufCEnumDescriptor *desc,
const char *name);
const ProtobufCEnumValue *protobuf_c_enum_descriptor_get_value_by_name(
const ProtobufCEnumDescriptor *desc, const char *name);
/**
* Look up a `ProtobufCEnumValue` from a `ProtobufCEnumDescriptor` by numeric
......@@ -839,10 +831,8 @@ protobuf_c_enum_descriptor_get_value_by_name(
* If not found.
*/
PROTOBUF_C__API
const ProtobufCEnumValue *
protobuf_c_enum_descriptor_get_value(
const ProtobufCEnumDescriptor *desc,
int value);
const ProtobufCEnumValue *protobuf_c_enum_descriptor_get_value(
const ProtobufCEnumDescriptor *desc, int value);
/**
* Look up a `ProtobufCFieldDescriptor` from a `ProtobufCMessageDescriptor` by
......@@ -858,10 +848,8 @@ protobuf_c_enum_descriptor_get_value(
* If not found or if the optimize_for = CODE_SIZE option was set.
*/
PROTOBUF_C__API
const ProtobufCFieldDescriptor *
protobuf_c_message_descriptor_get_field_by_name(
const ProtobufCMessageDescriptor *desc,
const char *name);
const ProtobufCFieldDescriptor *protobuf_c_message_descriptor_get_field_by_name(
const ProtobufCMessageDescriptor *desc, const char *name);
/**
* Look up a `ProtobufCFieldDescriptor` from a `ProtobufCMessageDescriptor` by
......@@ -877,10 +865,8 @@ protobuf_c_message_descriptor_get_field_by_name(
* If not found.
*/
PROTOBUF_C__API
const ProtobufCFieldDescriptor *
protobuf_c_message_descriptor_get_field(
const ProtobufCMessageDescriptor *desc,
unsigned value);
const ProtobufCFieldDescriptor *protobuf_c_message_descriptor_get_field(
const ProtobufCMessageDescriptor *desc, unsigned value);
/**
* Determine the number of bytes required to store the serialised message.
......@@ -891,9 +877,7 @@ protobuf_c_message_descriptor_get_field(
* Number of bytes.
*/
PROTOBUF_C__API
size_t
protobuf_c_message_get_packed_size(const ProtobufCMessage *message);
size_t protobuf_c_message_get_packed_size(const ProtobufCMessage *message);
/**
* Unpack a serialised message into an in-memory representation.
......@@ -913,12 +897,9 @@ protobuf_c_message_get_packed_size(const ProtobufCMessage *message);
* If an error occurred during unpacking.
*/
PROTOBUF_C__API
ProtobufCMessage *
protobuf_c_message_unpack(
const ProtobufCMessageDescriptor *descriptor,
ProtobufCAllocator *allocator,
size_t len,
const uint8_t *data);
ProtobufCMessage *protobuf_c_message_unpack(
const ProtobufCMessageDescriptor *descriptor, ProtobufCAllocator *allocator,
size_t len, const uint8_t *data);
/**
* Free an unpacked message object.
......@@ -933,9 +914,7 @@ protobuf_c_message_unpack(
* specify the default allocator.
*/
PROTOBUF_C__API
void
protobuf_c_message_free_unpacked(
ProtobufCMessage *message,
void protobuf_c_message_free_unpacked(ProtobufCMessage *message,
ProtobufCAllocator *allocator);
/**
......@@ -950,11 +929,11 @@ protobuf_c_message_free_unpacked(
* Message is invalid.
*/
PROTOBUF_C__API
protobuf_c_boolean
protobuf_c_message_check(const ProtobufCMessage *);
protobuf_c_boolean protobuf_c_message_check(const ProtobufCMessage *);
/** Message initialiser. */
#define PROTOBUF_C_MESSAGE_INIT(descriptor) { descriptor, 0, NULL }
#define PROTOBUF_C_MESSAGE_INIT(descriptor) \
{ descriptor, 0, NULL }
/**
* Initialise a message object from a message descriptor.
......@@ -965,9 +944,7 @@ protobuf_c_message_check(const ProtobufCMessage *);
* Allocated block of memory of size `descriptor->sizeof_message`.
*/
PROTOBUF_C__API
void
protobuf_c_message_init(
const ProtobufCMessageDescriptor *descriptor,
void protobuf_c_message_init(const ProtobufCMessageDescriptor *descriptor,
void *message);
/**
......@@ -977,8 +954,7 @@ protobuf_c_message_init(
* The service object to free.
*/
PROTOBUF_C__API
void
protobuf_c_service_destroy(ProtobufCService *service);
void protobuf_c_service_destroy(ProtobufCService *service);
/**
* Look up a `ProtobufCMethodDescriptor` by name.
......@@ -996,36 +972,29 @@ protobuf_c_service_destroy(ProtobufCService *service);
PROTOBUF_C__API
const ProtobufCMethodDescriptor *
protobuf_c_service_descriptor_get_method_by_name(
const ProtobufCServiceDescriptor *desc,
const char *name);
const ProtobufCServiceDescriptor *desc, const char *name);
/**
* Initialise a `ProtobufCBufferSimple` object.
*/
#define PROTOBUF_C_BUFFER_SIMPLE_INIT(array_of_bytes) \
{ \
{ protobuf_c_buffer_simple_append }, \
sizeof(array_of_bytes), \
0, \
(array_of_bytes), \
0, \
NULL \
}
{ \
{protobuf_c_buffer_simple_append}, sizeof(array_of_bytes), 0, \
(array_of_bytes), 0, NULL \
}
/**
* Clear a `ProtobufCBufferSimple` object, freeing any allocated memory.
*/
#define PROTOBUF_C_BUFFER_SIMPLE_CLEAR(simp_buf) \
do { \
do { \
if ((simp_buf)->must_free_data) { \
if ((simp_buf)->allocator != NULL) \
(simp_buf)->allocator->free( \
(simp_buf)->allocator, \
(simp_buf)->data); \
(simp_buf)->allocator->free((simp_buf)->allocator, (simp_buf)->data); \
else \
free((simp_buf)->data); \
} \
} while (0)
} while (0)
/**
* The `append` method for `ProtobufCBufferSimple`.
......@@ -1039,23 +1008,16 @@ do { \
* Data to append.
*/
PROTOBUF_C__API
void
protobuf_c_buffer_simple_append(
ProtobufCBuffer *buffer,
size_t len,
void protobuf_c_buffer_simple_append(ProtobufCBuffer *buffer, size_t len,
const unsigned char *data);
PROTOBUF_C__API
void
protobuf_c_service_generated_init(
ProtobufCService *service,
const ProtobufCServiceDescriptor *descriptor,
void protobuf_c_service_generated_init(
ProtobufCService *service, const ProtobufCServiceDescriptor *descriptor,
ProtobufCServiceDestroy destroy);
PROTOBUF_C__API
void
protobuf_c_service_invoke_internal(
ProtobufCService *service,
void protobuf_c_service_invoke_internal(ProtobufCService *service,
unsigned method_index,
const ProtobufCMessage *input,
ProtobufCClosure closure,
......
......@@ -16,8 +16,8 @@ limitations under the License. */
#include <map>
#include <string>
#include <vector>
#include <unordered_set>
#include <vector>
#include "framework/attribute.h"
#include "framework/scope.h"
......
......@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once
#include <unordered_map>
#include "common/log.h"
#include "common/enforce.h"
#include "common/log.h"
#include "common/variant.h"
#include "framework/framework.pb-c.h"
......@@ -27,7 +27,6 @@ class BlockDesc;
class Attribute {
public:
/*
* PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__INT = 0,
PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__FLOAT = 1,
......@@ -42,7 +41,8 @@ class Attribute {
PROTOBUF_C__FORCE_ENUM_TO_BE_INT_SIZE(PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE)
*
* */
static Attribute GetAttrValue(PaddleMobile__Framework__Proto__OpDesc__Attr *attr_desc) {
static Attribute GetAttrValue(
PaddleMobile__Framework__Proto__OpDesc__Attr *attr_desc) {
// std::cout << "begin get attr value" << std::endl;
Attribute attr;
switch (attr_desc->type) {
......
此差异已折叠。
此差异已折叠。
......@@ -33,17 +33,18 @@ std::vector<std::shared_ptr<OpDesc>> BlockDesc::Ops() const {
return res;
}
BlockDesc::BlockDesc(PaddleMobile__Framework__Proto__BlockDesc *desc): index_(desc->idx), parent_index_(desc->idx) {
BlockDesc::BlockDesc(PaddleMobile__Framework__Proto__BlockDesc *desc)
: index_(desc->idx), parent_index_(desc->idx) {
for (int i = 0; i < desc->n_vars; ++i) {
PaddleMobile__Framework__Proto__VarDesc *var_desc = desc->vars[i];
vars_[std::string(var_desc->name)] = std::shared_ptr<VarDesc>(new VarDesc(var_desc));
vars_[std::string(var_desc->name)] =
std::shared_ptr<VarDesc>(new VarDesc(var_desc));
}
for (int j = 0; j < desc->n_ops; ++j) {
PaddleMobile__Framework__Proto__OpDesc *op_desc = desc->ops[j];
ops_.emplace_back(new framework::OpDesc(op_desc));
}
}
} // namespace framework
......
......@@ -15,9 +15,9 @@ limitations under the License. */
#pragma once
#include "framework/framework.pb-c.h"
#include "framework/paddle_mobile_object.h"
#include "framework/program/op_desc.h"
#include "framework/program/var_desc.h"
#include "framework/paddle_mobile_object.h"
namespace paddle_mobile {
namespace framework {
......
......@@ -15,8 +15,8 @@ limitations under the License. */
#include <string>
#include <vector>
#include "program_desc.h"
#include "framework/program/tensor_desc.h"
#include "program_desc.h"
namespace paddle_mobile {
namespace framework {
......@@ -70,7 +70,6 @@ void ProgramDesc::Description(std::string header) {
}
}
}
}
#endif
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
//
// Created by liuRuiLong on 2018/5/26.
//
......
......@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle_mobile {
namespace framework {
enum VarType_Type{
enum VarType_Type {
VARTYPE_TYPE_BOOL = 0,
VARTYPE_TYPE_INT16 = 1,
VARTYPE_TYPE_INT32 = 2,
......@@ -59,17 +59,13 @@ class TensorDesc {
data_type_ = (VarType_Type)desc->data_type;
}
std::vector<int64_t> Dims() const {
return dims_;
};
VarType_Type DataType() const {
return data_type_;
}
std::vector<int64_t> Dims() const { return dims_; };
VarType_Type DataType() const { return data_type_; }
private:
std::vector<int64_t> dims_;
VarType_Type data_type_;
};
}
}
} // namespace framework
} // namespace paddle_mobile
......@@ -16,8 +16,5 @@ limitations under the License. */
namespace paddle_mobile {
namespace framework {
} // namespace framework
namespace framework {} // namespace framework
} // namespace paddle_mobile
......@@ -15,8 +15,8 @@ limitations under the License. */
#pragma once
#include "framework/framework.pb-c.h"
#include "framework/program/tensor_desc.h"
#include "framework/paddle_mobile_object.h"
#include "framework/program/tensor_desc.h"
namespace paddle_mobile {
namespace framework {
......@@ -48,7 +48,6 @@ PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__TYPE__BOOL = 0,
*/
class VarDesc {
public:
VarDesc(const VarDesc &var_desc) {
......@@ -65,7 +64,6 @@ class VarDesc {
VarType_Type type_;
VarType_Type data_type_;
* */
}
VarDesc(PaddleMobile__Framework__Proto__VarDesc *desc) {
type_ = (VarType_Type)desc->type->type;
......@@ -94,9 +92,7 @@ class VarDesc {
default:
data_type_ = tensor_desc_.DataType();
break;
}
}
std::string Name() const { return name_; }
......@@ -104,42 +100,40 @@ class VarDesc {
bool Persistable() const { return persistable_; }
const TensorDesc &Tensor_desc() const {
return tensor_desc_;
}
const TensorDesc &Tensor_desc() const { return tensor_desc_; }
// const proto::VarType::ChannelDesc &channel_desc() const {
// switch (desc_.type().type()) {
// case proto::VarType::CHANNEL:
// return desc_.type().channel();
// default:
// break;
// }
// }
// proto::VarType::Type GetDataType() const {
// switch (desc_.type().type()) {
// case proto::VarType::CHANNEL:
// return channel_desc().data_type();
// break;
// default:
// return tensor_desc().data_type();
// }
// }
// template <typename T>
// std::vector<T> RepeatedToVector(
// const google::protobuf::RepeatedField<T> &repeated_field) const {
// std::vector<T> ret;
// ret.reserve(repeated_field.size());
// std::copy(repeated_field.begin(), repeated_field.end(),
// std::back_inserter(ret));
// return ret;
// }
// std::vector<int64_t> GetShape() const {
// return this->RepeatedToVector(tensor_desc().dims());
// }
// switch (desc_.type().type()) {
// case proto::VarType::CHANNEL:
// return desc_.type().channel();
// default:
// break;
// }
// }
// proto::VarType::Type GetDataType() const {
// switch (desc_.type().type()) {
// case proto::VarType::CHANNEL:
// return channel_desc().data_type();
// break;
// default:
// return tensor_desc().data_type();
// }
// }
// template <typename T>
// std::vector<T> RepeatedToVector(
// const google::protobuf::RepeatedField<T> &repeated_field) const {
// std::vector<T> ret;
// ret.reserve(repeated_field.size());
// std::copy(repeated_field.begin(), repeated_field.end(),
// std::back_inserter(ret));
// return ret;
// }
// std::vector<int64_t> GetShape() const {
// return this->RepeatedToVector(tensor_desc().dims());
// }
private:
std::string name_;
......@@ -147,7 +141,6 @@ class VarDesc {
TensorDesc tensor_desc_;
VarType_Type type_;
VarType_Type data_type_;
};
} // namespace framework
......
......@@ -16,15 +16,15 @@ limitations under the License. */
#include <fstream>
#include <vector>
#include "common/log.h"
#include "common/enforce.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/operator.h"
#include "framework/lod_tensor.h"
#include "common/log.h"
#include "framework/framework.pb-c.h"
#include "framework/program/var_desc.h"
#include "framework/lod_tensor.h"
#include "framework/operator.h"
#include "framework/program/program_desc.h"
#include "framework/program/var_desc.h"
#include "framework/scope.h"
#include "framework/tensor.h"
namespace paddle_mobile {
......@@ -40,7 +40,7 @@ void ReadBinaryFile(const std::string &filename, std::string *contents) {
fin.close();
}
static size_t ReadBuffer (const char *file_name, uint8_t **out) {
static size_t ReadBuffer(const char *file_name, uint8_t **out) {
printf("%s \n", file_name);
FILE *fp;
fp = fopen(file_name, "rb");
......@@ -56,7 +56,7 @@ static size_t ReadBuffer (const char *file_name, uint8_t **out) {
size_t cur_len = 0;
size_t nread;
while ((nread=fread(*out + cur_len, 1, size - cur_len, fp)) != 0) {
while ((nread = fread(*out + cur_len, 1, size - cur_len, fp)) != 0) {
cur_len += nread;
}
fclose(fp);
......@@ -64,7 +64,8 @@ static size_t ReadBuffer (const char *file_name, uint8_t **out) {
}
template <typename Dtype, Precision P>
void Loader<Dtype, P>::LoadVar(framework::Variable *variable, const framework::VarDesc &var_desc,
void Loader<Dtype, P>::LoadVar(framework::Variable *variable,
const framework::VarDesc &var_desc,
const std::string &file_path) {
auto tensor = variable->GetMutable<framework::LoDTensor>();
std::ifstream is(file_path);
......@@ -109,22 +110,22 @@ void Loader<Dtype, P>::LoadVar(framework::Variable *variable, const framework::V
const framework::TensorDesc &desc = var_desc.Tensor_desc();
PaddleMobile__Framework__Proto__VarType__TensorDesc *tensor_desc = NULL;
// void *v;
// PaddleMobile__Framework__Proto__VarType__TensorDesc_Closure()(tensor_desc, buf.get());
// DLOG << "PaddleMobile__Framework__Proto__VarType__TensorDesc_Closure- " << tensor_desc;
// void *v;
// PaddleMobile__Framework__Proto__VarType__TensorDesc_Closure()(tensor_desc,
// buf.get());
// DLOG << "PaddleMobile__Framework__Proto__VarType__TensorDesc_Closure- " <<
// tensor_desc;
// framework::TensorDesc &tensor_desc = variable->
// PaddleMobile__Framework__Proto__ProgramDesc *c_program;
// uint8_t *proto_buf = NULL;
// size_t read_size = ReadBuffer(file_path.c_str(), &proto_buf);
// c_program = paddle_mobile__framework__proto__program_desc__unpack(NULL, read_size, buf);
// paddle_mobile__framework__proto__var_type__tensor_desc__init()
// framework::TensorDesc &tensor_desc = variable->
// PaddleMobile__Framework__Proto__ProgramDesc *c_program;
// uint8_t *proto_buf = NULL;
// size_t read_size = ReadBuffer(file_path.c_str(), &proto_buf);
// c_program = paddle_mobile__framework__proto__program_desc__unpack(NULL,
// read_size, buf);
// paddle_mobile__framework__proto__var_type__tensor_desc__init()
int memory_size = 1;
for (auto l : desc.Dims()) {
......@@ -173,7 +174,8 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
PADDLE_MOBILE_ENFORCE(buf != NULL, "read from __model__ is null");
c_program = paddle_mobile__framework__proto__program_desc__unpack(NULL, read_size, buf);
c_program = paddle_mobile__framework__proto__program_desc__unpack(
NULL, read_size, buf);
PADDLE_MOBILE_ENFORCE(c_program != NULL, "program is null");
......@@ -194,14 +196,14 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
for (const auto &block : originProgramDesc->Blocks()) {
for (int i = 0; i < block->Vars().size(); ++i) {
std::shared_ptr<framework::VarDesc> var_desc = block->Vars()[i];
// DLOG << "var name-- " << var_desc->Name();
// DLOG << "var name-- " << var_desc->Name();
auto var = scope->Var(var_desc->Name());
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
if (var_desc->Persistable() &&
var_desc->Type() != framework::VARTYPE_TYPE_FEED_MINIBATCH &&
var_desc->Type() != framework::VARTYPE_TYPE_FETCH_LIST) {
// DLOG << "to load var ";
// DLOG << "to load var ";
LoadVar(var, *var_desc, dirname + "/" + var_desc->Name());
}
......@@ -247,7 +249,8 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p) : program_(p) {
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc, framework::LoDTensor *tensor,
void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc,
framework::LoDTensor *tensor,
const std::string &file_path) {
std::ifstream is(file_path);
PADDLE_MOBILE_ENFORCE(is.is_open(), "open file: %s failed",
......@@ -290,7 +293,6 @@ void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc, framework
const framework::TensorDesc &desc = var_desc.Tensor_desc();
int memory_size = 1;
for (auto l : desc.Dims()) {
memory_size *= l;
......@@ -335,7 +337,8 @@ void Executor<Dtype, P>::InitMemory() {
auto var = program_.scope->Var(var_desc->Name());
if (var_desc->Persistable()) {
auto tensor = var->template GetMutable<framework::LoDTensor>();
LoadMemory(*var_desc, tensor, program_.model_path + "/" + var_desc->Name());
LoadMemory(*var_desc, tensor,
program_.model_path + "/" + var_desc->Name());
} else {
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
auto tensor = var->template GetMutable<framework::Tensor>();
......
......@@ -33,7 +33,9 @@ class Loader : PaddleMobileObject {
const framework::Program<Dtype, P> Load(const std::string &dirname);
private:
void LoadVar(framework::Variable *variable, const framework::VarDesc &var_desc, const std::string &file_path);
void LoadVar(framework::Variable *variable,
const framework::VarDesc &var_desc,
const std::string &file_path);
};
template <typename Dtype, Precision P = Precision::FP32>
......@@ -52,7 +54,8 @@ class Executor {
protected:
void InitMemory();
void LoadMemory(const framework::VarDesc var_desc, framework::LoDTensor *tensor, const std::string &file_path);
void LoadMemory(const framework::VarDesc var_desc,
framework::LoDTensor *tensor, const std::string &file_path);
framework::Program<Dtype> program_;
std::shared_ptr<framework::ProgramDesc> to_predict_program_;
void predict(const framework::Tensor &t, int block_id);
......
......@@ -116,8 +116,7 @@ inline std::string DataTypeToString(const VarType_Type type) {
}
}
inline std::ostream &operator<<(std::ostream &out,
const VarType_Type &type) {
inline std::ostream &operator<<(std::ostream &out, const VarType_Type &type) {
out << DataTypeToString(type);
return out;
}
......
......@@ -58,7 +58,6 @@ class Executor4Test : public Executor<DeviceType> {
for (std::shared_ptr<BlockDesc> block_desc : blocks) {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
for (std::shared_ptr<OpDesc> op : ops) {
if (op->Type() == op_type) {
/// test first meeting op in program
std::shared_ptr<paddle_mobile::framework::OperatorBase<DeviceType>>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册