未验证 提交 80d5f6fb 编写于 作者: Y Yanzhan Yang 提交者: GitHub

fix variant memory leak (#1951)

* fix variant memory leak test=develop

* fix variant memory leak & fix style test=develop

* fix cpplint test=develop
上级 60bbc691
...@@ -16,7 +16,9 @@ limitations under the License. */ ...@@ -16,7 +16,9 @@ limitations under the License. */
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include <memory>
#include <string> #include <string>
#include <utility>
#include "common/enforce.h" #include "common/enforce.h"
#include "common/log.h" #include "common/log.h"
#include "common/type_define.h" #include "common/type_define.h"
...@@ -30,41 +32,33 @@ struct IDToType { ...@@ -30,41 +32,33 @@ struct IDToType {
template <typename F, typename... Ts> template <typename F, typename... Ts>
struct VariantHelper { struct VariantHelper {
static const size_t size = sizeof(F) > VariantHelper<Ts...>::size inline static void Destroy(kTypeId_t type, void *raw_ptr) {
? sizeof(F)
: VariantHelper<Ts...>::size;
inline static void Destroy(kTypeId_t type, void *data) {
if (type == type_id<F>()) { if (type == type_id<F>()) {
reinterpret_cast<F *>(data)->~F(); auto ptr = reinterpret_cast<F *>(raw_ptr);
delete ptr;
} else { } else {
VariantHelper<Ts...>::Destroy(type, data); VariantHelper<Ts...>::Destroy(type, raw_ptr);
} }
} }
}; };
template <typename F> template <typename F>
struct VariantHelper<F> { struct VariantHelper<F> {
static const size_t size = sizeof(F); inline static void Destroy(kTypeId_t type, void *raw_ptr) {
inline static void Destroy(kTypeId_t type, void *data) {
if (type == type_id<F>()) { if (type == type_id<F>()) {
// reinterpret_cast<F*>(data)->~F(); auto ptr = reinterpret_cast<F *>(raw_ptr);
} else { delete ptr;
// std::cout << "未匹配到 " << std::endl;
} }
} }
}; };
template <size_t size> template <typename... Ts>
class RawData { struct VariantDeleter {
public: kTypeId_t type_ = type_id<void>().hash_code();
char data[size]; // NOLINT explicit VariantDeleter(kTypeId_t type) { type_ = type; }
RawData() {} void operator()(void *raw_ptr) {
RawData(const RawData &raw_data) { memcpy(data, raw_data.data, size); } // DLOG << "variant delete: " << type_ << " " << raw_ptr;
VariantHelper<Ts...>::Destroy(type_, raw_ptr);
RawData &operator=(const RawData &raw_data) {
memcpy(data, raw_data.data, size);
return *this;
} }
}; };
...@@ -78,43 +72,21 @@ struct Variant { ...@@ -78,43 +72,21 @@ struct Variant {
} }
virtual ~Variant() { virtual ~Variant() {
// helper::Destroy(type_id, &data); // DLOG << "variant deinit: " << type_ << " " << (void *)data_.get();
data_.reset();
} }
template <typename T, typename... Args> template <typename T, typename... Args>
void Set(Args &&... args) { void Set(Args &&... args) {
helper::Destroy(type_, data_.data); auto raw_ptr = new T(std::forward<Args>(args)...);
new (data_.data) T(std::forward<Args>(args)...);
type_ = type_id<T>().hash_code(); type_ = type_id<T>().hash_code();
} // DLOG << "variant new: " << type_ << " " << (void *)raw_ptr;
data_.reset(raw_ptr, VariantDeleter<Ts...>(type_));
void SetString(const std::string &string) {
helper::Destroy(type_, data_.data);
type_ = type_id<std::string>().hash_code();
strcpy(data_.data, string.c_str()); // NOLINT
}
std::string GetString() const {
if (type_ == type_id<std::string>()) {
return std::string(data_.data);
} else {
PADDLE_MOBILE_THROW_EXCEPTION(
" bad cast in variant data type not a string ");
exit(0);
}
} }
template <typename T> template <typename T>
T &Get() const { T &Get() const {
if (type_ == type_id<std::string>()) { return *const_cast<T *>(reinterpret_cast<const T *>(data_.get()));
PADDLE_MOBILE_THROW_EXCEPTION(
"Please use getString to get an string (to avoid of an issue with "
"gcc "
"stl lib with string copy)");
exit(0);
} else {
return *const_cast<T *>(reinterpret_cast<const T *>(data_.data));
}
} }
kTypeId_t TypeId() const { return type_; } kTypeId_t TypeId() const { return type_; }
...@@ -123,8 +95,7 @@ struct Variant { ...@@ -123,8 +95,7 @@ struct Variant {
static inline kTypeId_t invalid_type() { return type_id<void>().hash_code(); } static inline kTypeId_t invalid_type() { return type_id<void>().hash_code(); }
typedef VariantHelper<Ts...> helper; typedef VariantHelper<Ts...> helper;
kTypeId_t type_ = type_id<void>().hash_code(); kTypeId_t type_ = type_id<void>().hash_code();
// todo use an anto size to suite this. std::shared_ptr<void> data_;
RawData<64> data_;
}; };
template <typename T> template <typename T>
......
...@@ -51,7 +51,7 @@ class Attribute { ...@@ -51,7 +51,7 @@ class Attribute {
break; break;
} }
case PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__STRING: { case PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__STRING: {
attr.SetString(std::string(attr_desc->s)); attr.Set<std::string>(attr_desc->s);
break; break;
} }
case PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BOOLEANS: { case PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BOOLEANS: {
...@@ -119,12 +119,7 @@ class Attribute { ...@@ -119,12 +119,7 @@ class Attribute {
return variant_.Get<T>(); return variant_.Get<T>();
} }
Attribute &SetString(std::string string) { std::string GetString() const { return variant_.Get<std::string>(); }
variant_.SetString(string);
return *this;
}
std::string GetString() const { return variant_.GetString(); }
template <typename Vistor> template <typename Vistor>
static typename Vistor::type_t ApplyVistor(Vistor vistor, Attribute attr) { static typename Vistor::type_t ApplyVistor(Vistor vistor, Attribute attr) {
...@@ -133,7 +128,7 @@ class Attribute { ...@@ -133,7 +128,7 @@ class Attribute {
} else if (attr.variant_.TypeId() == type_id<float>()) { // NOLINT } else if (attr.variant_.TypeId() == type_id<float>()) { // NOLINT
return vistor(attr.variant_.Get<float>()); return vistor(attr.variant_.Get<float>());
} else if (attr.variant_.TypeId() == type_id<string>()) { } else if (attr.variant_.TypeId() == type_id<string>()) {
return vistor(attr.variant_.GetString()); return vistor(attr.variant_.Get<std::string>());
} else if (attr.variant_.TypeId() == type_id<vector<int>>()) { } else if (attr.variant_.TypeId() == type_id<vector<int>>()) {
return vistor(attr.variant_.Get<vector<int>>()); return vistor(attr.variant_.Get<vector<int>>());
} else if (attr.variant_.TypeId() == type_id<vector<float>>()) { } else if (attr.variant_.TypeId() == type_id<vector<float>>()) {
......
...@@ -17,8 +17,7 @@ limitations under the License. */ ...@@ -17,8 +17,7 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
const DDim &CLImageConverterDefault::InitImageDimInfoWith( DDim CLImageConverterDefault::InitImageDimInfoWith(const DDim &tensor_dim) {
const DDim &tensor_dim) {
size_t new_dims[] = {1, 1, 1, 1}; size_t new_dims[] = {1, 1, 1, 1};
for (int j = 0; j < tensor_dim.size(); ++j) { for (int j = 0; j < tensor_dim.size(); ++j) {
new_dims[4 - tensor_dim.size() + j] = tensor_dim[j]; new_dims[4 - tensor_dim.size() + j] = tensor_dim[j];
...@@ -119,8 +118,7 @@ void CLImageConverterDefault::ImageToNCHW(half_t *image, float *tensor, ...@@ -119,8 +118,7 @@ void CLImageConverterDefault::ImageToNCHW(half_t *image, float *tensor,
} }
} }
const DDim &CLImageConverterFolder::InitImageDimInfoWith( DDim CLImageConverterFolder::InitImageDimInfoWith(const DDim &tensor_dim) {
const DDim &tensor_dim) {
if (tensor_dim.size() <= 2) { if (tensor_dim.size() <= 2) {
int tdim[2] = {1, 1}; int tdim[2] = {1, 1};
if (tensor_dim.size() == 1) { if (tensor_dim.size() == 1) {
...@@ -218,8 +216,7 @@ void CLImageConverterFolder::ImageToNCHW(half_t *image, float *tensor, ...@@ -218,8 +216,7 @@ void CLImageConverterFolder::ImageToNCHW(half_t *image, float *tensor,
} }
} }
const DDim &CLImageConverterNWBlock::InitImageDimInfoWith( DDim CLImageConverterNWBlock::InitImageDimInfoWith(const DDim &tensor_dim) {
const DDim &tensor_dim) {
PADDLE_MOBILE_ENFORCE(tensor_dim.size() == 4, " tensor dim is not 4"); PADDLE_MOBILE_ENFORCE(tensor_dim.size() == 4, " tensor dim is not 4");
size_t N, C, H, W; size_t N, C, H, W;
N = tensor_dim[0]; N = tensor_dim[0];
...@@ -297,8 +294,7 @@ void CLImageConverterNWBlock::ImageToNCHW(half_t *image, float *tensor, ...@@ -297,8 +294,7 @@ void CLImageConverterNWBlock::ImageToNCHW(half_t *image, float *tensor,
DLOG << " init done"; DLOG << " init done";
} }
const DDim &CLImageConverterDWBlock::InitImageDimInfoWith( DDim CLImageConverterDWBlock::InitImageDimInfoWith(const DDim &tensor_dim) {
const DDim &tensor_dim) {
PADDLE_MOBILE_ENFORCE(tensor_dim.size() == 4, " tensor dim is not 4"); PADDLE_MOBILE_ENFORCE(tensor_dim.size() == 4, " tensor dim is not 4");
size_t N, C, H, W; size_t N, C, H, W;
N = tensor_dim[0]; N = tensor_dim[0];
...@@ -389,8 +385,7 @@ void CLImageConverterDWBlock::ImageToNCHW(half_t *image, float *tensor, ...@@ -389,8 +385,7 @@ void CLImageConverterDWBlock::ImageToNCHW(half_t *image, float *tensor,
} }
} }
const DDim &CLImageConverterNormal::InitImageDimInfoWith( DDim CLImageConverterNormal::InitImageDimInfoWith(const DDim &tensor_dim) {
const DDim &tensor_dim) {
PADDLE_MOBILE_ENFORCE(tensor_dim.size() <= 4 && tensor_dim.size() > 0, PADDLE_MOBILE_ENFORCE(tensor_dim.size() <= 4 && tensor_dim.size() > 0,
"tensor dim is not support "); "tensor dim is not support ");
size_t new_dims[] = {1, 1, 1, 1}; size_t new_dims[] = {1, 1, 1, 1};
...@@ -428,7 +423,7 @@ void CLImageConverterNormal::ImageToNCHW(half_t *image, float *tensor, ...@@ -428,7 +423,7 @@ void CLImageConverterNormal::ImageToNCHW(half_t *image, float *tensor,
default_converter.ImageToNCHW(image, tensor, image_dim, tensor_dim); default_converter.ImageToNCHW(image, tensor, image_dim, tensor_dim);
} }
const DDim &CLImageConverterWinoTransWeight::InitImageDimInfoWith( DDim CLImageConverterWinoTransWeight::InitImageDimInfoWith(
const DDim &tensor_dim) { const DDim &tensor_dim) {
PADDLE_MOBILE_ENFORCE(tensor_dim.size() == 4, " tensor dim is not 4"); PADDLE_MOBILE_ENFORCE(tensor_dim.size() == 4, " tensor dim is not 4");
size_t N, C, H, W; size_t N, C, H, W;
...@@ -448,7 +443,7 @@ void CLImageConverterWinoTransWeight::ImageToNCHW(half_t *image, float *tensor, ...@@ -448,7 +443,7 @@ void CLImageConverterWinoTransWeight::ImageToNCHW(half_t *image, float *tensor,
const DDim &image_dim, const DDim &image_dim,
const DDim &tensor_dim) {} const DDim &tensor_dim) {}
const DDim &CLImageConverterConv2dTransposeTransWeight::InitImageDimInfoWith( DDim CLImageConverterConv2dTransposeTransWeight::InitImageDimInfoWith(
const DDim &tensor_dim) { const DDim &tensor_dim) {
size_t new_dims[] = {1, 1, 1, 1}; size_t new_dims[] = {1, 1, 1, 1};
for (int j = 0; j < tensor_dim.size(); ++j) { for (int j = 0; j < tensor_dim.size(); ++j) {
......
...@@ -27,12 +27,12 @@ class CLImageConverterBase { ...@@ -27,12 +27,12 @@ class CLImageConverterBase {
virtual void ImageToNCHW(half_t *image, float *nchw, const DDim &image_dim, virtual void ImageToNCHW(half_t *image, float *nchw, const DDim &image_dim,
const DDim &tensor_dim) = 0; const DDim &tensor_dim) = 0;
virtual const DDim &InitImageDimInfoWith(const DDim &tensor_dim) = 0; virtual DDim InitImageDimInfoWith(const DDim &tensor_dim) = 0;
}; };
class CLImageConverterDefault : public CLImageConverterBase { class CLImageConverterDefault : public CLImageConverterBase {
public: public:
const DDim &InitImageDimInfoWith(const DDim &tensor_dim); DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *nchw, half_t *image, const DDim &tensor_dim); void NCHWToImage(float *nchw, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim, void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim); const DDim &tensor_dim);
...@@ -40,7 +40,7 @@ class CLImageConverterDefault : public CLImageConverterBase { ...@@ -40,7 +40,7 @@ class CLImageConverterDefault : public CLImageConverterBase {
class CLImageConverterFolder : public CLImageConverterBase { class CLImageConverterFolder : public CLImageConverterBase {
public: public:
const DDim &InitImageDimInfoWith(const DDim &tensor_dim); DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim); void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim, void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim); const DDim &tensor_dim);
...@@ -65,7 +65,7 @@ class CLImageConverterFolder : public CLImageConverterBase { ...@@ -65,7 +65,7 @@ class CLImageConverterFolder : public CLImageConverterBase {
class CLImageConverterNormal : public CLImageConverterBase { class CLImageConverterNormal : public CLImageConverterBase {
public: public:
const DDim &InitImageDimInfoWith(const DDim &tensor_dim); DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim); void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim, void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim); const DDim &tensor_dim);
...@@ -89,13 +89,13 @@ class CLImageConverterNormal : public CLImageConverterBase { ...@@ -89,13 +89,13 @@ class CLImageConverterNormal : public CLImageConverterBase {
}; };
class CLImageConverterNWBlock : public CLImageConverterBase { class CLImageConverterNWBlock : public CLImageConverterBase {
const DDim &InitImageDimInfoWith(const DDim &tensor_dim); DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim); void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim, void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim); const DDim &tensor_dim);
}; };
class CLImageConverterDWBlock : public CLImageConverterBase { class CLImageConverterDWBlock : public CLImageConverterBase {
const DDim &InitImageDimInfoWith(const DDim &tensor_dim); DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim); void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim, void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim); const DDim &tensor_dim);
...@@ -103,7 +103,7 @@ class CLImageConverterDWBlock : public CLImageConverterBase { ...@@ -103,7 +103,7 @@ class CLImageConverterDWBlock : public CLImageConverterBase {
class CLImageConverterWinoTransWeight : public CLImageConverterBase { class CLImageConverterWinoTransWeight : public CLImageConverterBase {
public: public:
const DDim &InitImageDimInfoWith(const DDim &tensor_dim); DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim); void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim, void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim); const DDim &tensor_dim);
...@@ -111,7 +111,7 @@ class CLImageConverterWinoTransWeight : public CLImageConverterBase { ...@@ -111,7 +111,7 @@ class CLImageConverterWinoTransWeight : public CLImageConverterBase {
class CLImageConverterConv2dTransposeTransWeight : public CLImageConverterBase { class CLImageConverterConv2dTransposeTransWeight : public CLImageConverterBase {
public: public:
const DDim &InitImageDimInfoWith(const DDim &tensor_dim); DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim); void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim, void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim); const DDim &tensor_dim);
......
...@@ -62,7 +62,6 @@ struct DDim { ...@@ -62,7 +62,6 @@ struct DDim {
return vistor(d.var.Get<Dim<9>>()); return vistor(d.var.Get<Dim<9>>());
} else { } else {
PADDLE_MOBILE_ENFORCE(false, " dim not support"); PADDLE_MOBILE_ENFORCE(false, " dim not support");
exit(0);
} }
} }
...@@ -73,6 +72,8 @@ struct DDim { ...@@ -73,6 +72,8 @@ struct DDim {
var.Set<Dim<D>>(in); var.Set<Dim<D>>(in);
} }
DDim(const DDim &in) { setNewDim(in); }
/*implicit*/ DDim(std::initializer_list<int64_t> init_list); /*implicit*/ DDim(std::initializer_list<int64_t> init_list);
template <int D> template <int D>
...@@ -81,11 +82,42 @@ struct DDim { ...@@ -81,11 +82,42 @@ struct DDim {
return *this; return *this;
} }
DDim &operator=(const DDim &in) {
setNewDim(in);
return *this;
}
void setNewDim(const DDim &d) {
if (d.var.TypeId() == type_id<Dim<0>>()) {
return var.Set<Dim<0>>(d.var.Get<Dim<0>>());
} else if (d.var.TypeId() == type_id<Dim<1>>()) {
return var.Set<Dim<1>>(d.var.Get<Dim<1>>());
} else if (d.var.TypeId() == type_id<Dim<2>>()) {
return var.Set<Dim<2>>(d.var.Get<Dim<2>>());
} else if (d.var.TypeId() == type_id<Dim<3>>()) {
return var.Set<Dim<3>>(d.var.Get<Dim<3>>());
} else if (d.var.TypeId() == type_id<Dim<4>>()) {
return var.Set<Dim<4>>(d.var.Get<Dim<4>>());
} else if (d.var.TypeId() == type_id<Dim<5>>()) {
return var.Set<Dim<5>>(d.var.Get<Dim<5>>());
} else if (d.var.TypeId() == type_id<Dim<6>>()) {
return var.Set<Dim<6>>(d.var.Get<Dim<6>>());
} else if (d.var.TypeId() == type_id<Dim<7>>()) {
return var.Set<Dim<7>>(d.var.Get<Dim<7>>());
} else if (d.var.TypeId() == type_id<Dim<8>>()) {
return var.Set<Dim<8>>(d.var.Get<Dim<8>>());
} else if (d.var.TypeId() == type_id<Dim<9>>()) {
return var.Set<Dim<9>>(d.var.Get<Dim<9>>());
} else {
PADDLE_MOBILE_ENFORCE(false, " dim not support");
}
}
int64_t &operator[](int idx); int64_t &operator[](int idx);
int64_t operator[](int idx) const; int64_t operator[](int idx) const;
DDimVar getVar() { return var; } DDimVar getVar() const { return var; }
bool operator==(DDim d) const; bool operator==(DDim d) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册