提交 cec61f3d 编写于 作者: L liuruilong

remove third-party folder

上级 8a3a01ee
...@@ -27,29 +27,12 @@ include_directories(src/) ...@@ -27,29 +27,12 @@ include_directories(src/)
# INSTALL_COMMAND "make" "PREFIX=${CMAKE_BINARY_DIR}/" "install" # INSTALL_COMMAND "make" "PREFIX=${CMAKE_BINARY_DIR}/" "install"
# ) # )
#set_target_properties(openblas_proj PROPERTIES EXCLUDE_FROM_ALL 1) #set_target_properties(openblas_proj PROPERTIES EXCLUDE_FROM_ALL 1)
# link protobuf
include_directories(third-party/protobuf/include)
include_directories(third-party/protobuf-c-decoder/include)
if (ANDROID)
link_directories(third-party/protobuf/armeabi-v7a)
else()
# link openblas
link_directories(third-party/protobuf/lib)
link_directories(third-party/protobuf-c-decoder/lib)
endif ()
#add_dependencies(paddle-mobile openblas_proj) #add_dependencies(paddle-mobile openblas_proj)
# gen static # gen static
ADD_LIBRARY(paddle-mobile SHARED ${PADDLE_MOBILE_CC} ${PADDLE_MOBILE_H}) ADD_LIBRARY(paddle-mobile SHARED ${PADDLE_MOBILE_CC} ${PADDLE_MOBILE_H})
if (ANDROID)
# openblas.a need log lib
target_link_libraries(paddle-mobile protobuf-lite)
else()
target_link_libraries(paddle-mobile protobuf-lite)
target_link_libraries(paddle-mobile protobuf-c-decoder)
endif ()
#add_dependencies(paddle-mobile openblas_proj) #add_dependencies(paddle-mobile openblas_proj)
add_subdirectory(test) add_subdirectory(test)
此差异已折叠。
此差异已折叠。
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <map> #include <map>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_set>
#include "framework/attribute.h" #include "framework/attribute.h"
#include "framework/scope.h" #include "framework/scope.h"
......
...@@ -14,10 +14,10 @@ limitations under the License. */ ...@@ -14,10 +14,10 @@ limitations under the License. */
#pragma once #pragma once
#include <unordered_map>
#include "common/log.h" #include "common/log.h"
#include "common/enforce.h" #include "common/enforce.h"
#include "common/variant.h" #include "common/variant.h"
#include "framework/framework.pb.h"
#include "framework/framework.pb-c.h" #include "framework/framework.pb-c.h"
namespace paddle_mobile { namespace paddle_mobile {
...@@ -27,69 +27,6 @@ class BlockDesc; ...@@ -27,69 +27,6 @@ class BlockDesc;
class Attribute { class Attribute {
public: public:
static Attribute GetAttrValue(const proto::OpDesc::Attr &attr_desc) {
// std::cout << "begin get attr value" << std::endl;
Attribute attr;
switch (attr_desc.type()) {
case proto::AttrType::BOOLEAN: {
attr.Set<bool>(attr_desc.b());
break;
}
case proto::AttrType::INT: {
attr.Set<int>(attr_desc.i());
break;
}
case proto::AttrType::FLOAT: {
attr.Set<float>(attr_desc.f());
break;
}
case proto::AttrType::STRING: {
attr.Set<std::string>(attr_desc.s());
break;
}
case proto::AttrType::BOOLEANS: {
std::vector<bool> val(attr_desc.bools_size());
for (int i = 0; i < attr_desc.bools_size(); ++i) {
val[i] = attr_desc.bools(i);
}
attr.Set<std::vector<bool>>(val);
break;
}
case proto::AttrType::INTS: {
std::vector<int> val(attr_desc.ints_size());
for (int i = 0; i < attr_desc.ints_size(); ++i) {
val[i] = attr_desc.ints(i);
}
attr.Set<std::vector<int>>(val);
break;
}
case proto::AttrType::FLOATS: {
std::vector<float> val(attr_desc.floats_size());
for (int i = 0; i < attr_desc.floats_size(); ++i) {
val[i] = attr_desc.floats(i);
}
attr.Set<std::vector<float>>(val);
break;
}
case proto::AttrType::STRINGS: {
std::vector<std::string> val(attr_desc.strings_size());
for (int i = 0; i < attr_desc.strings_size(); ++i) {
val[i] = attr_desc.strings(i);
}
attr.Set<std::vector<std::string>>(val);
break;
}
case proto::AttrType::LONG: {
attr.Set<int64_t>(attr_desc.l());
break;
}
default:
// std::cout << " not support " << std::endl;
break;
}
// std::cout << "end get attr value" << std::endl;
return attr;
}
/* /*
* PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__INT = 0, * PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__INT = 0,
......
...@@ -14,8 +14,6 @@ limitations under the License. */ ...@@ -14,8 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include "framework/framework.pb.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#ifndef PROTOBUF_C_framework_2eproto__INCLUDED #ifndef PROTOBUF_C_framework_2eproto__INCLUDED
#define PROTOBUF_C_framework_2eproto__INCLUDED #define PROTOBUF_C_framework_2eproto__INCLUDED
#include "protobuf-c.h" #include "common/protobuf-c.h"
PROTOBUF_C__BEGIN_DECLS PROTOBUF_C__BEGIN_DECLS
......
此差异已折叠。
此差异已折叠。
...@@ -244,58 +244,5 @@ void AppendLoD(LoD *lod, const LoD &lod_length) { ...@@ -244,58 +244,5 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
} }
} }
void SerializeToStream(std::ostream &os, const LoDTensor &tensor) {
{ // the 1st field, uint32_t version for LoDTensor
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
{
// the 2st field, LoD information
// uint64_t lod_level
// uint64_t lod_level_1 size in byte.
// int* lod_level_1 data
// ...
auto lod = tensor.lod();
uint64_t size = lod.size();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
for (auto &each : lod) {
size = each.size() * sizeof(framework::LoD::value_type::value_type);
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
os.write(reinterpret_cast<const char *>(each.data()),
static_cast<std::streamsize>(size));
}
}
// the 3st field, Tensor
TensorToStream(os, static_cast<Tensor>(tensor));
}
void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
{
// the 1st field, unit32_t version for LoDTensor
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
// PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is
// supported");
}
{
// the 2st field, LoD information
uint64_t lod_level;
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
lod.resize(lod_level);
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
std::vector<size_t> tmp(size / sizeof(size_t));
is.read(reinterpret_cast<char *>(tmp.data()),
static_cast<std::streamsize>(size));
lod[i] = tmp;
}
}
// the 3st filed, Tensor
TensorFromStream(is, static_cast<Tensor *>(tensor));
}
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#include <string> #include <string>
#include "common/log.h" #include "common/log.h"
#include "common/type_define.h" #include "common/type_define.h"
#include "framework/framework.pb.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "framework/data_layout.h" #include "framework/data_layout.h"
#include "framework/framework.pb.h" #include "framework/program/tensor_desc.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
...@@ -33,10 +33,10 @@ struct OpKernelType { ...@@ -33,10 +33,10 @@ struct OpKernelType {
// place, data_type, library_type kinds less than 2^8 // place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8; constexpr static int LEFT_SHIFT = 8;
proto::VarType::Type data_type_; VarType_Type data_type_;
DataLayout data_layout_; DataLayout data_layout_;
OpKernelType(proto::VarType::Type data_type, OpKernelType(VarType_Type data_type,
DataLayout data_layout = DataLayout::kAnyLayout) DataLayout data_layout = DataLayout::kAnyLayout)
: data_type_(data_type), data_layout_(data_layout) {} : data_type_(data_type), data_layout_(data_layout) {}
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include "framework/framework.pb.h"
#include "framework/framework.pb-c.h" #include "framework/framework.pb-c.h"
#include "framework/program/op_desc.h" #include "framework/program/op_desc.h"
#include "framework/program/var_desc.h" #include "framework/program/var_desc.h"
......
...@@ -19,7 +19,6 @@ limitations under the License. */ ...@@ -19,7 +19,6 @@ limitations under the License. */
#include "common/log.h" #include "common/log.h"
#include "common/type_define.h" #include "common/type_define.h"
#include "framework/framework.pb.h"
#include "framework/framework.pb-c.h" #include "framework/framework.pb-c.h"
#include "framework/paddle_mobile_object.h" #include "framework/paddle_mobile_object.h"
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include "framework/framework.pb.h"
#include "framework/framework.pb-c.h" #include "framework/framework.pb-c.h"
#include "framework/program/tensor_desc.h" #include "framework/program/tensor_desc.h"
#include "framework/paddle_mobile_object.h" #include "framework/paddle_mobile_object.h"
......
...@@ -132,37 +132,6 @@ bool TensorContainsInf(const framework::Tensor &tensor) { ...@@ -132,37 +132,6 @@ bool TensorContainsInf(const framework::Tensor &tensor) {
return Any(tensor, predicate); return Any(tensor, predicate);
} }
void TensorToStream(std::ostream &os, const Tensor &tensor) {
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
proto::VarType::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims());
auto *pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
std::copy(dims.begin(), dims.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
auto out = desc.SerializeAsString();
os.write(out.data(), size);
}
{ // the 3rd field, tensor data
uint64_t size = tensor.memory_size();
auto *data_ptr = tensor.data<void>();
// PADDLE_ENFORCE(size <
// std::numeric_limits<std::streamsize>::max(),
// "Index overflow when writing tensor");
os.write(static_cast<const char *>(data_ptr),
static_cast<std::streamsize>(size));
}
}
struct DeserializedDataFunctor { struct DeserializedDataFunctor {
DeserializedDataFunctor(void **buf, Tensor *tensor) DeserializedDataFunctor(void **buf, Tensor *tensor)
: buf_(buf), tensor_(tensor) {} : buf_(buf), tensor_(tensor) {}
...@@ -176,32 +145,5 @@ struct DeserializedDataFunctor { ...@@ -176,32 +145,5 @@ struct DeserializedDataFunctor {
Tensor *tensor_; Tensor *tensor_;
}; };
void TensorFromStream(std::istream &is, framework::Tensor *tensor) {
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
// PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
proto::VarType::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char *>(buf.get()), size);
// PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
// "Cannot parse tensor desc");
}
{ // read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
tensor->Resize(framework::make_ddim(dims));
void *buf;
framework::VisitDataType(desc.data_type(),
DeserializedDataFunctor(&buf, tensor));
is.read(static_cast<char *>(buf), tensor->memory_size());
}
}
} // namespace framework } // namespace framework
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "framework.pb.h"
#include "memory/t_malloc.h" #include "memory/t_malloc.h"
#include "platform/data_type.h" #include "platform/data_type.h"
#include "tensor.h" #include "tensor.h"
......
...@@ -22,7 +22,6 @@ limitations under the License. */ ...@@ -22,7 +22,6 @@ limitations under the License. */
#include "framework/tensor.h" #include "framework/tensor.h"
#include "framework/operator.h" #include "framework/operator.h"
#include "framework/lod_tensor.h" #include "framework/lod_tensor.h"
#include "framework/framework.pb.h"
#include "framework/framework.pb-c.h" #include "framework/framework.pb-c.h"
#include "framework/program/var_desc.h" #include "framework/program/var_desc.h"
#include "framework/program/program_desc.h" #include "framework/program/program_desc.h"
...@@ -110,6 +109,14 @@ void Loader<Dtype, P>::LoadVar(framework::Variable *variable, const framework::V ...@@ -110,6 +109,14 @@ void Loader<Dtype, P>::LoadVar(framework::Variable *variable, const framework::V
const framework::TensorDesc &desc = var_desc.Tensor_desc(); const framework::TensorDesc &desc = var_desc.Tensor_desc();
PaddleMobile__Framework__Proto__VarType__TensorDesc *tensor_desc = NULL;
// void *v;
// PaddleMobile__Framework__Proto__VarType__TensorDesc_Closure()(tensor_desc, buf.get());
// DLOG << "PaddleMobile__Framework__Proto__VarType__TensorDesc_Closure- " << tensor_desc;
// framework::TensorDesc &tensor_desc = variable-> // framework::TensorDesc &tensor_desc = variable->
// PaddleMobile__Framework__Proto__ProgramDesc *c_program; // PaddleMobile__Framework__Proto__ProgramDesc *c_program;
// uint8_t *proto_buf = NULL; // uint8_t *proto_buf = NULL;
...@@ -240,7 +247,7 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p) : program_(p) { ...@@ -240,7 +247,7 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p) : program_(p) {
} }
template <typename Dtype, Precision P> template <typename Dtype, Precision P>
void Executor<Dtype, P>::LoadMemory(framework::LoDTensor *tensor, void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc, framework::LoDTensor *tensor,
const std::string &file_path) { const std::string &file_path) {
std::ifstream is(file_path); std::ifstream is(file_path);
PADDLE_MOBILE_ENFORCE(is.is_open(), "open file: %s failed", PADDLE_MOBILE_ENFORCE(is.is_open(), "open file: %s failed",
...@@ -281,39 +288,36 @@ void Executor<Dtype, P>::LoadMemory(framework::LoDTensor *tensor, ...@@ -281,39 +288,36 @@ void Executor<Dtype, P>::LoadMemory(framework::LoDTensor *tensor,
std::unique_ptr<char[]> buf(new char[size]); std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char *>(buf.get()), size); is.read(reinterpret_cast<char *>(buf.get()), size);
framework::proto::VarType::TensorDesc desc; const framework::TensorDesc &desc = var_desc.Tensor_desc();
desc.ParseFromArray(buf.get(), size);
int memory_size = 1; int memory_size = 1;
for (auto l : desc.dims()) { for (auto l : desc.Dims()) {
memory_size *= l; memory_size *= l;
} }
std::vector<int64_t> dims; tensor->Resize(framework::make_ddim(desc.Dims()));
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
tensor->Resize(framework::make_ddim(dims));
void *memory = tensor; void *memory = tensor;
int type_size = 0; int type_size = 0;
switch (desc.data_type()) { switch (desc.DataType()) {
case framework::proto::VarType::FP16: case framework::VARTYPE_TYPE_FP16:
type_size = 2; type_size = 2;
break; break;
case framework::proto::VarType::FP32: case framework::VARTYPE_TYPE_FP32:
type_size = 4; type_size = 4;
memory = tensor->mutable_data<float>(); memory = tensor->mutable_data<float>();
break; break;
case framework::proto::VarType::FP64: case framework::VARTYPE_TYPE_FP64:
type_size = 8; type_size = 8;
break; break;
case framework::proto::VarType::INT32: case framework::VARTYPE_TYPE_INT32:
type_size = 4; type_size = 4;
break; break;
case framework::proto::VarType::INT64: case framework::VARTYPE_TYPE_INT64:
type_size = 8; type_size = 8;
break; break;
case framework::proto::VarType::BOOL: case framework::VARTYPE_TYPE_BOOL:
type_size = 1; type_size = 1;
break; break;
default: default:
...@@ -331,7 +335,7 @@ void Executor<Dtype, P>::InitMemory() { ...@@ -331,7 +335,7 @@ void Executor<Dtype, P>::InitMemory() {
auto var = program_.scope->Var(var_desc->Name()); auto var = program_.scope->Var(var_desc->Name());
if (var_desc->Persistable()) { if (var_desc->Persistable()) {
auto tensor = var->template GetMutable<framework::LoDTensor>(); auto tensor = var->template GetMutable<framework::LoDTensor>();
LoadMemory(tensor, program_.model_path + "/" + var_desc->Name()); LoadMemory(*var_desc, tensor, program_.model_path + "/" + var_desc->Name());
} else { } else {
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) { if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
auto tensor = var->template GetMutable<framework::Tensor>(); auto tensor = var->template GetMutable<framework::Tensor>();
......
...@@ -52,7 +52,7 @@ class Executor { ...@@ -52,7 +52,7 @@ class Executor {
protected: protected:
void InitMemory(); void InitMemory();
void LoadMemory(framework::LoDTensor *tensor, const std::string &file_path); void LoadMemory(const framework::VarDesc var_desc, framework::LoDTensor *tensor, const std::string &file_path);
framework::Program<Dtype> program_; framework::Program<Dtype> program_;
std::shared_ptr<framework::ProgramDesc> to_predict_program_; std::shared_ptr<framework::ProgramDesc> to_predict_program_;
void predict(const framework::Tensor &t, int block_id); void predict(const framework::Tensor &t, int block_id);
......
...@@ -16,12 +16,13 @@ limitations under the License. */ ...@@ -16,12 +16,13 @@ limitations under the License. */
#include <string> #include <string>
#include <typeindex> #include <typeindex>
#include "framework/framework.pb.h"
#include "framework/program/tensor_desc.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
inline proto::VarType::Type ToDataType(std::type_index type) { inline VarType_Type ToDataType(std::type_index type) {
/*if (typeid(platform::float16).hash_code() == type.hash_code()) { /*if (typeid(platform::float16).hash_code() == type.hash_code()) {
return proto::VarType::FP16; return proto::VarType::FP16;
} else */ } else */
...@@ -31,34 +32,34 @@ inline proto::VarType::Type ToDataType(std::type_index type) { ...@@ -31,34 +32,34 @@ inline proto::VarType::Type ToDataType(std::type_index type) {
// One fix to this is to replace float with const float because // One fix to this is to replace float with const float because
// typeid(T) == typeid(const T) // typeid(T) == typeid(const T)
// http://en.cppreference.com/w/cpp/language/typeid // http://en.cppreference.com/w/cpp/language/typeid
return proto::VarType::FP32; return VARTYPE_TYPE_FP32;
} else if (typeid(const double).hash_code() == type.hash_code()) { } else if (typeid(const double).hash_code() == type.hash_code()) {
return proto::VarType::FP64; return VARTYPE_TYPE_FP64;
} else if (typeid(const int).hash_code() == type.hash_code()) { } else if (typeid(const int).hash_code() == type.hash_code()) {
return proto::VarType::INT32; return VARTYPE_TYPE_INT32;
} else if (typeid(const int64_t).hash_code() == type.hash_code()) { } else if (typeid(const int64_t).hash_code() == type.hash_code()) {
return proto::VarType::INT64; return VARTYPE_TYPE_INT64;
} else if (typeid(const bool).hash_code() == type.hash_code()) { } else if (typeid(const bool).hash_code() == type.hash_code()) {
return proto::VarType::BOOL; return VARTYPE_TYPE_BOOL;
} else { } else {
// PADDLE_THROW("Not supported"); // PADDLE_THROW("Not supported");
// std::cout << "Not supported"; // std::cout << "Not supported";
} }
} }
inline std::type_index ToTypeIndex(proto::VarType::Type type) { inline std::type_index ToTypeIndex(VarType_Type type) {
switch (type) { switch (type) {
// case proto::VarType::FP16: // case proto::VarType::FP16:
// return typeid(platform::float16); // return typeid(platform::float16);
case proto::VarType::FP32: case VARTYPE_TYPE_FP32:
return typeid(float); return typeid(float);
case proto::VarType::FP64: case VARTYPE_TYPE_FP64:
return typeid(double); return typeid(double);
case proto::VarType::INT32: case VARTYPE_TYPE_INT32:
return typeid(int); return typeid(int);
case proto::VarType::INT64: case VARTYPE_TYPE_INT64:
return typeid(int64_t); return typeid(int64_t);
case proto::VarType::BOOL: case VARTYPE_TYPE_BOOL:
return typeid(bool); return typeid(bool);
default: default:
// PADDLE_THROW("Not support type %d", type); // PADDLE_THROW("Not support type %d", type);
...@@ -67,24 +68,24 @@ inline std::type_index ToTypeIndex(proto::VarType::Type type) { ...@@ -67,24 +68,24 @@ inline std::type_index ToTypeIndex(proto::VarType::Type type) {
} }
template <typename Visitor> template <typename Visitor>
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { inline void VisitDataType(VarType_Type type, Visitor visitor) {
switch (type) { switch (type) {
// case proto::VarType::FP16: // case proto::VarType::FP16:
// visitor.template operator()<platform::float16>(); // visitor.template operator()<platform::float16>();
// break; // break;
case proto::VarType::FP32: case VARTYPE_TYPE_FP32:
visitor.template operator()<float>(); visitor.template operator()<float>();
break; break;
case proto::VarType::FP64: case VARTYPE_TYPE_FP64:
visitor.template operator()<double>(); visitor.template operator()<double>();
break; break;
case proto::VarType::INT32: case VARTYPE_TYPE_INT32:
visitor.template operator()<int>(); visitor.template operator()<int>();
break; break;
case proto::VarType::INT64: case VARTYPE_TYPE_INT64:
visitor.template operator()<int64_t>(); visitor.template operator()<int64_t>();
break; break;
case proto::VarType::BOOL: case VARTYPE_TYPE_BOOL:
visitor.template operator()<bool>(); visitor.template operator()<bool>();
break; break;
default: default:
...@@ -93,21 +94,21 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { ...@@ -93,21 +94,21 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
} }
} }
inline std::string DataTypeToString(const proto::VarType::Type type) { inline std::string DataTypeToString(const VarType_Type type) {
switch (type) { switch (type) {
case proto::VarType::FP16: case VARTYPE_TYPE_FP16:
return "float16"; return "float16";
case proto::VarType::FP32: case VARTYPE_TYPE_FP32:
return "float32"; return "float32";
case proto::VarType::FP64: case VARTYPE_TYPE_FP64:
return "float64"; return "float64";
case proto::VarType::INT16: case VARTYPE_TYPE_INT16:
return "int16"; return "int16";
case proto::VarType::INT32: case VARTYPE_TYPE_INT32:
return "int32"; return "int32";
case proto::VarType::INT64: case VARTYPE_TYPE_INT64:
return "int64"; return "int64";
case proto::VarType::BOOL: case VARTYPE_TYPE_BOOL:
return "bool"; return "bool";
default: default:
// PADDLE_THROW("Not support type %d", type); // PADDLE_THROW("Not support type %d", type);
...@@ -116,7 +117,7 @@ inline std::string DataTypeToString(const proto::VarType::Type type) { ...@@ -116,7 +117,7 @@ inline std::string DataTypeToString(const proto::VarType::Type type) {
} }
inline std::ostream &operator<<(std::ostream &out, inline std::ostream &operator<<(std::ostream &out,
const proto::VarType::Type &type) { const VarType_Type &type) {
out << DataTypeToString(type); out << DataTypeToString(type);
return out; return out;
} }
......
...@@ -21,7 +21,6 @@ limitations under the License. */ ...@@ -21,7 +21,6 @@ limitations under the License. */
#include "./test_helper.h" #include "./test_helper.h"
#include "common/enforce.h" #include "common/enforce.h"
#include "common/log.h" #include "common/log.h"
#include "framework/framework.pb.h"
#include "framework/lod_tensor.h" #include "framework/lod_tensor.h"
#include "framework/operator.h" #include "framework/operator.h"
#include "framework/program/block_desc.h" #include "framework/program/block_desc.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册