Unverified commit 1d6bfdca authored by N nihui, committed by GitHub

fix pnnx pass on fp16 weight, common fp16 conversion routines (#4743)

Parent 0dbe5a71
......@@ -29,6 +29,7 @@
#endif
#include "storezip.h"
#include "utils.h"
namespace pnnx {
......@@ -429,13 +430,7 @@ Attribute::Attribute(const at::Tensor& t)
if (shape.size() > 0)
{
int size = shape[0];
for (size_t i = 1; i < shape.size(); i++)
{
size *= shape[i];
}
data.resize(size * type_to_elemsize(type));
data.resize(elemcount() * type_to_elemsize(type));
memcpy((void*)data.data(), (const void*)t.cpu().contiguous().data_ptr(), data.size());
}
}
......@@ -448,14 +443,93 @@ Attribute::Attribute(const std::initializer_list<int>& _shape, const std::vector
if (shape.size() > 0)
{
int size = shape[0];
for (size_t i = 1; i < shape.size(); i++)
{
size *= shape[i];
}
data.resize(size * type_to_elemsize(type));
data.resize(elemcount() * type_to_elemsize(type));
memcpy((void*)data.data(), (const void*)t.data(), data.size());
}
}
size_t Attribute::elemsize() const
{
return type_to_elemsize(type);
}
int Attribute::elemcount() const
{
if (shape.empty())
return 0;
int size = shape[0];
for (size_t i = 1; i < shape.size(); i++)
{
size *= shape[i];
}
return size;
}
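For example, an attribute with shape {16, 3, 3, 3} has elemcount() == 432 regardless of whether its payload is stored as f32 or f16; only elemsize() (4 vs 2 bytes) and therefore data.size() differ.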
std::vector<float> Attribute::get_float32_data() const
{
std::vector<float> v(elemcount());
if (type == 1)
{
memcpy((void*)v.data(), (const void*)data.data(), data.size());
}
else if (type == 2)
{
// f64
const double* p = (const double*)data.data();
for (size_t i = 0; i < v.size(); i++)
{
v[i] = float(p[i]);
}
}
else if (type == 3)
{
// f16
const unsigned short* p = (const unsigned short*)data.data();
for (size_t i = 0; i < v.size(); i++)
{
v[i] = float16_to_float32(p[i]);
}
}
else
{
fprintf(stderr, "cannot convert type %d to float32 data\n", type);
}
return v;
}
void Attribute::set_float32_data(const std::vector<float>& newdata)
{
data.resize(newdata.size() * elemsize());
if (type == 1)
{
memcpy((void*)data.data(), (const void*)newdata.data(), data.size());
}
else if (type == 2)
{
// f64
double* p = (double*)data.data();
for (size_t i = 0; i < newdata.size(); i++)
{
p[i] = newdata[i];
}
}
else if (type == 3)
{
// f16
unsigned short* p = (unsigned short*)data.data();
for (size_t i = 0; i < newdata.size(); i++)
{
p[i] = float32_to_float16(newdata[i]);
}
}
else
{
fprintf(stderr, "cannot convert float32 data to type %d\n", type);
}
}
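The pattern the fusion passes below switch to is roughly the following (a sketch, not pass code; the include name and the scale_attribute helper are illustrative only):

#include <vector>
#include "ir.h" // pnnx::Attribute, assumed include name

// Scale every element of a weight attribute in place, regardless of whether
// it is stored as f32 (type 1), f64 (type 2) or f16 (type 3).
static void scale_attribute(pnnx::Attribute& attr, float scale)
{
    std::vector<float> w = attr.get_float32_data(); // always elemcount() floats
    for (size_t i = 0; i < w.size(); i++)
        w[i] *= scale;
    attr.set_float32_data(w); // re-encoded to match attr.type
}

Because the conversion happens on read and on write, the attribute keeps its original element type, which is what lets the conv/bn and attention fusions below operate on fp16 weights without special-casing.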
......
......@@ -205,6 +205,13 @@ public:
Attribute(const std::initializer_list<int>& shape, const std::vector<float>& t);
size_t elemsize() const;
int elemcount() const;
// convenient routines for manipulating fp32/fp16 weights
std::vector<float> get_float32_data() const;
void set_float32_data(const std::vector<float>& data);
// 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool 10=c64 11=c128 12=c32
int type;
std::vector<int> shape;
......
......@@ -15,6 +15,7 @@
#include "eliminate_noop_math.h"
#include <algorithm>
#include "utils.h"
#include "pass_level2.h"
#include "pass_level4/dead_code_elimination.h"
......@@ -77,6 +78,16 @@ static bool attribute_is_all_constant(const Operator* op_attr, float vf, int vi)
return false;
}
}
else if (attr.type == 3)
{
// f16
const unsigned short* p = (const unsigned short*)attr.data.data();
for (int i = 0; i < size; i++)
{
if (float16_to_float32(p[i]) != vf)
return false;
}
}
else if (attr.type == 4)
{
const int* p = (const int*)attr.data.data();
......
......@@ -63,10 +63,10 @@ pnnx.Output output 1 0 out
bool has_bn_affine = captured_params.at("affine").b;
bool has_conv_bias = captured_params.at("bias").b;
const float* bn_running_mean = (const float*)captured_attrs.at("op_1.running_mean").data.data();
const float* bn_running_var = (const float*)captured_attrs.at("op_1.running_var").data.data();
const float* bn_weight = has_bn_affine ? (const float*)captured_attrs.at("op_1.weight").data.data() : 0;
const float* bn_bias = has_bn_affine ? (const float*)captured_attrs.at("op_1.bias").data.data() : 0;
auto bn_running_mean = captured_attrs.at("op_1.running_mean").get_float32_data();
auto bn_running_var = captured_attrs.at("op_1.running_var").get_float32_data();
auto bn_weight = has_bn_affine ? captured_attrs.at("op_1.weight").get_float32_data() : std::vector<float>();
auto bn_bias = has_bn_affine ? captured_attrs.at("op_1.bias").get_float32_data() : std::vector<float>();
// a = bias - slope * mean / sqrt(var + eps)
// b = slope / sqrt(var + eps)
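For reference, the folding performed in the write() body below follows directly from these two definitions; a minimal standalone sketch of the per-channel arithmetic (not the pass source, variable names mirror the ones above):

#include <cmath>
#include <vector>

// Fold BatchNorm into the preceding convolution:
//   y = b[i] * conv_out[i] + a[i]  ==>  scale channel i's weights by b[i]
//   and rewrite its bias as bias[i] * b[i] + a[i].
static void fold_bn_into_conv(std::vector<float>& conv_weight, std::vector<float>& conv_bias,
                              const std::vector<float>& bn_running_mean, const std::vector<float>& bn_running_var,
                              const std::vector<float>& bn_weight, const std::vector<float>& bn_bias,
                              bool has_bn_affine, float eps, int channels, int weight_per_outch)
{
    for (int i = 0; i < channels; i++)
    {
        const float slope = has_bn_affine ? bn_weight[i] : 1.f;
        const float shift = has_bn_affine ? bn_bias[i] : 0.f;

        const float b = slope / std::sqrt(bn_running_var[i] + eps);
        const float a = shift - slope * bn_running_mean[i] / std::sqrt(bn_running_var[i] + eps);

        for (int j = 0; j < weight_per_outch; j++)
            conv_weight[i * weight_per_outch + j] *= b;

        conv_bias[i] = conv_bias[i] * b + a;
    }
}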
......@@ -100,22 +100,20 @@ pnnx.Output output 1 0 out
{
// init bias as zero
op->attrs["bias"] = Attribute();
op->attrs["bias"].type = 1;
op->attrs["bias"].type = op->attrs["weight"].type;
op->attrs["bias"].shape = {channels};
op->attrs["bias"].data.resize(channels * sizeof(float));
memset(op->attrs["bias"].data.data(), 0, channels * sizeof(float));
op->attrs["bias"].set_float32_data(std::vector<float>(channels, 0.f));
}
float* conv_weight = (float*)op->attrs["weight"].data.data();
float* conv_bias = (float*)op->attrs["bias"].data.data();
auto conv_weight = op->attrs["weight"].get_float32_data();
auto conv_bias = op->attrs["bias"].get_float32_data();
const int outch = captured_params.at("out_channels").i;
const int weight_per_outch = op->attrs["weight"].data.size() / sizeof(float) / outch;
const int weight_per_outch = op->attrs["weight"].elemcount() / outch;
for (int i = 0; i < channels; i++)
{
float* conv_weight_outch = conv_weight + weight_per_outch * i;
float* conv_weight_outch = (float*)conv_weight.data() + weight_per_outch * i;
for (int j = 0; j < weight_per_outch; j++)
{
conv_weight_outch[j] *= b[i];
......@@ -123,6 +121,9 @@ pnnx.Output output 1 0 out
conv_bias[i] = conv_bias[i] * b[i] + a[i];
}
op->attrs["weight"].set_float32_data(conv_weight);
op->attrs["bias"].set_float32_data(conv_bias);
}
};
......
......@@ -63,10 +63,10 @@ pnnx.Output output 1 0 out
bool has_bn_affine = captured_params.at("affine").b;
bool has_conv_bias = captured_params.at("bias").b;
const float* bn_running_mean = (const float*)captured_attrs.at("op_1.running_mean").data.data();
const float* bn_running_var = (const float*)captured_attrs.at("op_1.running_var").data.data();
const float* bn_weight = has_bn_affine ? (const float*)captured_attrs.at("op_1.weight").data.data() : 0;
const float* bn_bias = has_bn_affine ? (const float*)captured_attrs.at("op_1.bias").data.data() : 0;
auto bn_running_mean = captured_attrs.at("op_1.running_mean").get_float32_data();
auto bn_running_var = captured_attrs.at("op_1.running_var").get_float32_data();
auto bn_weight = has_bn_affine ? captured_attrs.at("op_1.weight").get_float32_data() : std::vector<float>();
auto bn_bias = has_bn_affine ? captured_attrs.at("op_1.bias").get_float32_data() : std::vector<float>();
// a = bias - slope * mean / sqrt(var + eps)
// b = slope / sqrt(var + eps)
......@@ -100,22 +100,20 @@ pnnx.Output output 1 0 out
{
// init bias as zero
op->attrs["bias"] = Attribute();
op->attrs["bias"].type = 1;
op->attrs["bias"].type = op->attrs["weight"].type;
op->attrs["bias"].shape = {channels};
op->attrs["bias"].data.resize(channels * sizeof(float));
memset(op->attrs["bias"].data.data(), 0, channels * sizeof(float));
op->attrs["bias"].set_float32_data(std::vector<float>(channels, 0.f));
}
float* conv_weight = (float*)op->attrs["weight"].data.data();
float* conv_bias = (float*)op->attrs["bias"].data.data();
auto conv_weight = op->attrs["weight"].get_float32_data();
auto conv_bias = op->attrs["bias"].get_float32_data();
const int outch = captured_params.at("out_channels").i;
const int weight_per_outch = op->attrs["weight"].data.size() / sizeof(float) / outch;
const int weight_per_outch = op->attrs["weight"].elemcount() / outch;
for (int i = 0; i < channels; i++)
{
float* conv_weight_outch = conv_weight + weight_per_outch * i;
float* conv_weight_outch = conv_weight.data() + weight_per_outch * i;
for (int j = 0; j < weight_per_outch; j++)
{
conv_weight_outch[j] *= b[i];
......@@ -123,6 +121,9 @@ pnnx.Output output 1 0 out
conv_bias[i] = conv_bias[i] * b[i] + a[i];
}
op->attrs["weight"].set_float32_data(conv_weight);
op->attrs["bias"].set_float32_data(conv_bias);
}
};
......
......@@ -63,10 +63,10 @@ pnnx.Output output 1 0 out
bool has_bn_affine = captured_params.at("affine").b;
bool has_convtranspose_bias = captured_params.at("bias").b;
const float* bn_running_mean = (const float*)captured_attrs.at("op_1.running_mean").data.data();
const float* bn_running_var = (const float*)captured_attrs.at("op_1.running_var").data.data();
const float* bn_weight = has_bn_affine ? (const float*)captured_attrs.at("op_1.weight").data.data() : 0;
const float* bn_bias = has_bn_affine ? (const float*)captured_attrs.at("op_1.bias").data.data() : 0;
auto bn_running_mean = captured_attrs.at("op_1.running_mean").get_float32_data();
auto bn_running_var = captured_attrs.at("op_1.running_var").get_float32_data();
auto bn_weight = has_bn_affine ? captured_attrs.at("op_1.weight").get_float32_data() : std::vector<float>();
auto bn_bias = has_bn_affine ? captured_attrs.at("op_1.bias").get_float32_data() : std::vector<float>();
// a = bias - slope * mean / sqrt(var + eps)
// b = slope / sqrt(var + eps)
......@@ -100,15 +100,13 @@ pnnx.Output output 1 0 out
{
// init bias as zero
op->attrs["bias"] = Attribute();
op->attrs["bias"].type = 1;
op->attrs["bias"].type = op->attrs["weight"].type;
op->attrs["bias"].shape = {channels};
op->attrs["bias"].data.resize(channels * sizeof(float));
memset(op->attrs["bias"].data.data(), 0, channels * sizeof(float));
op->attrs["bias"].set_float32_data(std::vector<float>(channels, 0.f));
}
float* conv_weight = (float*)op->attrs["weight"].data.data();
float* conv_bias = (float*)op->attrs["bias"].data.data();
auto conv_weight = op->attrs["weight"].get_float32_data();
auto conv_bias = op->attrs["bias"].get_float32_data();
// group-inch/group-outch/group-kw
const int inch = captured_params.at("in_channels").i;
......@@ -121,7 +119,7 @@ pnnx.Output output 1 0 out
for (int g = 0; g < groups; g++)
{
float* wg = conv_weight + g * inch_g * outch_g * kw;
float* wg = (float*)conv_weight.data() + g * inch_g * outch_g * kw;
for (int i = 0; i < inch_g; i++)
{
for (int j = 0; j < outch_g; j++)
......@@ -138,6 +136,9 @@ pnnx.Output output 1 0 out
{
conv_bias[i] = conv_bias[i] * b[i] + a[i];
}
op->attrs["weight"].set_float32_data(conv_weight);
op->attrs["bias"].set_float32_data(conv_bias);
}
};
......
......@@ -63,10 +63,10 @@ pnnx.Output output 1 0 out
bool has_bn_affine = captured_params.at("affine").b;
bool has_convtranspose_bias = captured_params.at("bias").b;
const float* bn_running_mean = (const float*)captured_attrs.at("op_1.running_mean").data.data();
const float* bn_running_var = (const float*)captured_attrs.at("op_1.running_var").data.data();
const float* bn_weight = has_bn_affine ? (const float*)captured_attrs.at("op_1.weight").data.data() : 0;
const float* bn_bias = has_bn_affine ? (const float*)captured_attrs.at("op_1.bias").data.data() : 0;
auto bn_running_mean = captured_attrs.at("op_1.running_mean").get_float32_data();
auto bn_running_var = captured_attrs.at("op_1.running_var").get_float32_data();
auto bn_weight = has_bn_affine ? captured_attrs.at("op_1.weight").get_float32_data() : std::vector<float>();
auto bn_bias = has_bn_affine ? captured_attrs.at("op_1.bias").get_float32_data() : std::vector<float>();
// a = bias - slope * mean / sqrt(var + eps)
// b = slope / sqrt(var + eps)
......@@ -100,15 +100,13 @@ pnnx.Output output 1 0 out
{
// init bias as zero
op->attrs["bias"] = Attribute();
op->attrs["bias"].type = 1;
op->attrs["bias"].type = op->attrs["weight"].type;
op->attrs["bias"].shape = {channels};
op->attrs["bias"].data.resize(channels * sizeof(float));
memset(op->attrs["bias"].data.data(), 0, channels * sizeof(float));
op->attrs["bias"].set_float32_data(std::vector<float>(channels, 0.f));
}
float* conv_weight = (float*)op->attrs["weight"].data.data();
float* conv_bias = (float*)op->attrs["bias"].data.data();
auto conv_weight = op->attrs["weight"].get_float32_data();
auto conv_bias = op->attrs["bias"].get_float32_data();
// group-inch/group-outch/group-kh-kw
const int inch = captured_params.at("in_channels").i;
......@@ -123,7 +121,7 @@ pnnx.Output output 1 0 out
for (int g = 0; g < groups; g++)
{
float* wg = conv_weight + g * inch_g * outch_g * maxk;
float* wg = (float*)conv_weight.data() + g * inch_g * outch_g * maxk;
for (int i = 0; i < inch_g; i++)
{
for (int j = 0; j < outch_g; j++)
......@@ -140,6 +138,9 @@ pnnx.Output output 1 0 out
{
conv_bias[i] = conv_bias[i] * b[i] + a[i];
}
op->attrs["weight"].set_float32_data(conv_weight);
op->attrs["bias"].set_float32_data(conv_bias);
}
};
......
......@@ -57,10 +57,10 @@ pnnx.Output output 1 0 out
bool has_bn_affine = captured_params.at("affine").b;
bool has_conv_bias = captured_params.at("bias").b;
const float* bn_running_mean = (const float*)captured_attrs.at("op_1.running_mean").data.data();
const float* bn_running_var = (const float*)captured_attrs.at("op_1.running_var").data.data();
const float* bn_weight = has_bn_affine ? (const float*)captured_attrs.at("op_1.weight").data.data() : 0;
const float* bn_bias = has_bn_affine ? (const float*)captured_attrs.at("op_1.bias").data.data() : 0;
auto bn_running_mean = captured_attrs.at("op_1.running_mean").get_float32_data();
auto bn_running_var = captured_attrs.at("op_1.running_var").get_float32_data();
auto bn_weight = has_bn_affine ? captured_attrs.at("op_1.weight").get_float32_data() : std::vector<float>();
auto bn_bias = has_bn_affine ? captured_attrs.at("op_1.bias").get_float32_data() : std::vector<float>();
// a = bias - slope * mean / sqrt(var + eps)
// b = slope / sqrt(var + eps)
......@@ -94,21 +94,19 @@ pnnx.Output output 1 0 out
{
// init bias as zero
op->attrs["bias"] = Attribute();
op->attrs["bias"].type = 1;
op->attrs["bias"].type = op->attrs["weight"].type;
op->attrs["bias"].shape = {channels};
op->attrs["bias"].data.resize(channels * sizeof(float));
memset(op->attrs["bias"].data.data(), 0, channels * sizeof(float));
op->attrs["bias"].set_float32_data(std::vector<float>(channels, 0.f));
}
float* conv_weight = (float*)op->attrs["weight"].data.data();
float* conv_bias = (float*)op->attrs["bias"].data.data();
auto conv_weight = op->attrs["weight"].get_float32_data();
auto conv_bias = op->attrs["bias"].get_float32_data();
const int weight_per_outch = op->params["in_features"].i;
for (int i = 0; i < channels; i++)
{
float* conv_weight_outch = conv_weight + weight_per_outch * i;
float* conv_weight_outch = (float*)conv_weight.data() + weight_per_outch * i;
for (int j = 0; j < weight_per_outch; j++)
{
conv_weight_outch[j] *= b[i];
......@@ -116,6 +114,9 @@ pnnx.Output output 1 0 out
conv_bias[i] = conv_bias[i] * b[i] + a[i];
}
op->attrs["weight"].set_float32_data(conv_weight);
op->attrs["bias"].set_float32_data(conv_bias);
}
};
......
......@@ -132,11 +132,9 @@ pnnx.Output output 1 0 out
{
// init bias as zero
op->attrs["in_proj_bias"] = Attribute();
op->attrs["in_proj_bias"].type = 1;
op->attrs["in_proj_bias"].type = op->attrs["in_proj_weight"].type;
op->attrs["in_proj_bias"].shape = {embed_dim * 3};
op->attrs["in_proj_bias"].data.resize(embed_dim * 3 * sizeof(float));
memset(op->attrs["in_proj_bias"].data.data(), 0, embed_dim * 3 * sizeof(float));
op->attrs["in_proj_bias"].set_float32_data(std::vector<float>(embed_dim * 3, 0.f));
}
}
......@@ -151,11 +149,9 @@ pnnx.Output output 1 0 out
{
// init bias as zero
op->attrs["out_proj.bias"] = Attribute();
op->attrs["out_proj.bias"].type = 1;
op->attrs["out_proj.bias"].type = op->attrs["out_proj.weight"].type;
op->attrs["out_proj.bias"].shape = {embed_dim};
op->attrs["out_proj.bias"].data.resize(embed_dim * sizeof(float));
memset(op->attrs["out_proj.bias"].data.data(), 0, embed_dim * sizeof(float));
op->attrs["out_proj.bias"].set_float32_data(std::vector<float>(embed_dim, 0.f));
}
}
}
......@@ -337,36 +333,23 @@ pnnx.Output output 1 0 out
op->params["add_bias_kv"] = false;
op->params["bias"] = bias;
op->attrs["in_proj_weight"] = Attribute();
op->attrs["in_proj_weight"].type = 1;
op->attrs["in_proj_weight"].shape = {embed_dim * 3, embed_dim};
op->attrs["in_proj_weight"].data.resize(embed_dim * 3 * embed_dim * sizeof(float));
// combine qkv weight
{
float* in_proj_weight_ptr = (float*)op->attrs["in_proj_weight"].data.data();
memcpy(in_proj_weight_ptr, captured_attrs.at("op_0.weight").data.data(), embed_dim * embed_dim * sizeof(float));
in_proj_weight_ptr += embed_dim * embed_dim;
memcpy(in_proj_weight_ptr, captured_attrs.at("op_1.weight").data.data(), embed_dim * embed_dim * sizeof(float));
in_proj_weight_ptr += embed_dim * embed_dim;
memcpy(in_proj_weight_ptr, captured_attrs.at("op_2.weight").data.data(), embed_dim * embed_dim * sizeof(float));
}
op->attrs["in_proj_weight"] = captured_attrs.at("op_0.weight") + captured_attrs.at("op_1.weight") + captured_attrs.at("op_2.weight");
op->attrs["out_proj.weight"] = captured_attrs.at("out_proj.weight");
if (bias)
{
op->attrs["in_proj_bias"] = Attribute();
op->attrs["in_proj_bias"].type = 1;
op->attrs["in_proj_bias"].type = op->attrs["in_proj_weight"].type;
op->attrs["in_proj_bias"].shape = {embed_dim * 3};
op->attrs["in_proj_bias"].data.resize(embed_dim * 3 * sizeof(float));
// combine qkv bias
std::vector<float> in_proj_bias(embed_dim * 3);
{
float* in_proj_bias_ptr = (float*)op->attrs["in_proj_bias"].data.data();
float* in_proj_bias_ptr = (float*)in_proj_bias.data();
if (q_bias)
{
memcpy(in_proj_bias_ptr, captured_attrs.at("op_0.bias").data.data(), embed_dim * sizeof(float));
auto qb = captured_attrs.at("op_0.bias").get_float32_data();
memcpy(in_proj_bias_ptr, (const void*)qb.data(), embed_dim * sizeof(float));
}
else
{
......@@ -375,7 +358,8 @@ pnnx.Output output 1 0 out
in_proj_bias_ptr += embed_dim;
if (k_bias)
{
memcpy(in_proj_bias_ptr, captured_attrs.at("op_1.bias").data.data(), embed_dim * sizeof(float));
auto kb = captured_attrs.at("op_1.bias").get_float32_data();
memcpy(in_proj_bias_ptr, (const void*)kb.data(), embed_dim * sizeof(float));
}
else
{
......@@ -384,13 +368,15 @@ pnnx.Output output 1 0 out
in_proj_bias_ptr += embed_dim;
if (v_bias)
{
memcpy(in_proj_bias_ptr, captured_attrs.at("op_2.bias").data.data(), embed_dim * sizeof(float));
auto vb = captured_attrs.at("op_2.bias").get_float32_data();
memcpy(in_proj_bias_ptr, (const void*)vb.data(), embed_dim * sizeof(float));
}
else
{
memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float));
}
}
op->attrs["in_proj_bias"].set_float32_data(in_proj_bias);
if (out_bias)
{
......@@ -400,11 +386,9 @@ pnnx.Output output 1 0 out
{
// init bias as zero
op->attrs["out_proj.bias"] = Attribute();
op->attrs["out_proj.bias"].type = 1;
op->attrs["out_proj.bias"].type = op->attrs["out_proj.weight"].type;
op->attrs["out_proj.bias"].shape = {embed_dim};
op->attrs["out_proj.bias"].data.resize(embed_dim * sizeof(float));
memset(op->attrs["out_proj.bias"].data.data(), 0, embed_dim * sizeof(float));
op->attrs["out_proj.bias"].set_float32_data(std::vector<float>(embed_dim, 0.f));
}
}
}
......@@ -536,16 +520,16 @@ pnnx.Output output 1 0 out
if (bias)
{
op->attrs["in_proj_bias"] = Attribute();
op->attrs["in_proj_bias"].type = 1;
op->attrs["in_proj_bias"].type = op->attrs["q_proj_weight"].type;
op->attrs["in_proj_bias"].shape = {embed_dim * 3};
op->attrs["in_proj_bias"].data.resize(embed_dim * 3 * sizeof(float));
// combine qkv bias
std::vector<float> in_proj_bias(embed_dim * 3);
{
float* in_proj_bias_ptr = (float*)op->attrs["in_proj_bias"].data.data();
float* in_proj_bias_ptr = (float*)in_proj_bias.data();
if (q_bias)
{
memcpy(in_proj_bias_ptr, captured_attrs.at("op_0.bias").data.data(), embed_dim * sizeof(float));
auto qb = captured_attrs.at("op_0.bias").get_float32_data();
memcpy(in_proj_bias_ptr, (const void*)qb.data(), embed_dim * sizeof(float));
}
else
{
......@@ -554,7 +538,8 @@ pnnx.Output output 1 0 out
in_proj_bias_ptr += embed_dim;
if (k_bias)
{
memcpy(in_proj_bias_ptr, captured_attrs.at("op_1.bias").data.data(), embed_dim * sizeof(float));
auto kb = captured_attrs.at("op_1.bias").get_float32_data();
memcpy(in_proj_bias_ptr, (const void*)kb.data(), embed_dim * sizeof(float));
}
else
{
......@@ -563,13 +548,15 @@ pnnx.Output output 1 0 out
in_proj_bias_ptr += embed_dim;
if (v_bias)
{
memcpy(in_proj_bias_ptr, captured_attrs.at("op_2.bias").data.data(), embed_dim * sizeof(float));
auto vb = captured_attrs.at("op_2.bias").get_float32_data();
memcpy(in_proj_bias_ptr, (const void*)vb.data(), embed_dim * sizeof(float));
}
else
{
memset(in_proj_bias_ptr, 0, embed_dim * sizeof(float));
}
}
op->attrs["in_proj_bias"].set_float32_data(in_proj_bias);
if (out_bias)
{
......@@ -579,11 +566,9 @@ pnnx.Output output 1 0 out
{
// init bias as zero
op->attrs["out_proj.bias"] = Attribute();
op->attrs["out_proj.bias"].type = 1;
op->attrs["out_proj.bias"].type = op->attrs["out_proj.weight"].type;
op->attrs["out_proj.bias"].shape = {embed_dim};
op->attrs["out_proj.bias"].data.resize(embed_dim * sizeof(float));
memset(op->attrs["out_proj.bias"].data.data(), 0, embed_dim * sizeof(float));
op->attrs["out_proj.bias"].set_float32_data(std::vector<float>(embed_dim, 0.f));
}
}
}
......@@ -1284,15 +1269,223 @@ pnnx.Output output 1 0 out
if (attn_mask->consumers.size() > 1 || attn_mask->producer->type != "pnnx.Attribute")
return false;
return true;
}
void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
fuse_multiheadattention_pass_sameqkv::write(op, captured_params, captured_attrs);
Operand* attn_mask = op->inputs[1];
Operator* op_attr = attn_mask->producer;
// hack attn_mask shape
attn_mask->shape = std::vector<int>{attn_mask->shape[2], attn_mask->shape[3]};
const std::string key = op_attr->attrs.begin()->first;
op_attr->attrs[key].shape = attn_mask->shape;
}
};
class fuse_multiheadattention_pass_17 : public fuse_multiheadattention_pass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
17 16
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 attn_mask
nn.Linear op_0 1 1 input 8 bias=%qkv_bias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
Tensor.reshape op_1 1 1 8 9 shape=%shape
torch.permute op_2 1 1 9 10 dims=(2,0,3,1,4)
torch.unbind op_3 1 3 10 11 12 13 dim=0
pnnx.Expression op_4 1 1 11 14 expr=%expr
torch.transpose op_5 1 1 12 15 dim0=-2 dim1=-1
torch.matmul op_6 2 1 14 15 16
pnnx.Expression op_7 2 1 16 attn_mask 18 expr=%expr2
F.softmax op_8 1 1 18 19 dim=-1
torch.matmul op_9 2 1 19 13 20
torch.transpose op_10 1 1 20 21 dim0=1 dim1=2
Tensor.reshape op_11 1 1 21 22 shape=%shape2
nn.Linear out_proj 1 1 22 out bias=%out_proj_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
bool match(const std::map<std::string, Parameter>& captured_params) const
{
bool matched = fuse_multiheadattention_pass::match(captured_params);
if (!matched)
return false;
if (captured_params.at("expr2").s != "add(@0,@1)")
return false;
return true;
}
bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
const Operator* op_7 = matched_operators.at("op_7");
// support constant attention mask only atm
Operand* attn_mask = op_7->inputs[1];
if (attn_mask->consumers.size() > 1 || attn_mask->producer->type != "pnnx.Attribute")
return false;
return true;
}
void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
fuse_multiheadattention_pass::write(op, captured_params, captured_attrs);
Operand* attn_mask = op->inputs[1];
Operator* op_attr = attn_mask->producer;
int batch = op->inputs[0]->shape[0];
// hack attn_mask shape
attn_mask->shape = std::vector<int>{batch * attn_mask->shape[1], attn_mask->shape[2], attn_mask->shape[3]};
const std::string key = op_attr->attrs.begin()->first;
op_attr->attrs[key].shape = attn_mask->shape;
// hack attn_mask value
std::vector<char>& data = op_attr->attrs[key].data;
size_t len = data.size();
data.resize(len * batch);
for (int i = 1; i < batch; i++)
{
memcpy(&data[len * i], &data[0], len);
}
}
};
class fuse_multiheadattention_pass_18 : public fuse_multiheadattention_pass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
20 19
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 attn_mask
nn.Linear op_0 1 1 input 25 bias=%qkv_bias in_features=%embed_dim out_features=%qkv_out_features @bias @weight
Tensor.reshape op_1 1 1 25 26 shape=%shape
torch.permute op_2 1 1 26 27 dims=(2,0,3,1,4)
torch.unbind op_3 1 3 27 28 29 30 dim=0
pnnx.Expression op_4 1 1 28 31 expr=%expr
torch.transpose op_5 1 1 29 32 dim0=-2 dim1=-1
torch.matmul op_6 2 1 31 32 33
pnnx.Expression op_7 2 1 33 attn_mask 35 expr=%expr2
Tensor.view op_8 1 1 35 36 shape=%shapep
pnnx.Attribute op_9 0 1 37 @mask2
pnnx.Expression op_10 2 1 36 37 38 expr=%expr2
Tensor.view op_11 1 1 38 39 shape=%shapeq
F.softmax op_12 1 1 39 40 dim=-1
torch.matmul op_13 2 1 40 30 41
torch.transpose op_14 1 1 41 42 dim0=1 dim1=2
Tensor.reshape op_15 1 1 42 43 shape=%shape2
nn.Linear out_proj 1 1 43 out bias=%out_proj_bias in_features=%embed_dim out_features=%embed_dim @bias @weight
pnnx.Output output 1 0 out
)PNNXIR";
}
bool match(const std::map<std::string, Parameter>& captured_params) const
{
bool matched = fuse_multiheadattention_pass::match(captured_params);
if (!matched)
return false;
if (captured_params.at("expr2").s != "add(@0,@1)")
return false;
// (1,64,3,49,49)
// (-1,3,49,49)
const std::vector<int>& shapep = captured_params.at("shapep").ai;
const std::vector<int>& shapeq = captured_params.at("shapeq").ai;
if (shapep.size() != 5 || shapeq.size() != 4)
return false;
if (shapep[0] != 1 || (shapep[1] != shapeq[0] && shapeq[0] != -1) || shapep[2] != shapeq[1] || shapep[3] != shapeq[2] || shapep[4] != shapeq[3])
return false;
return true;
}
bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
const Operator* op_7 = matched_operators.at("op_7");
// support constant attention mask only atm
Operand* attn_mask = op_7->inputs[1];
if (attn_mask->consumers.size() > 1 || attn_mask->producer->type != "pnnx.Attribute")
return false;
// @mask2=(1,64,1,49,49)f32
if (attn_mask->shape.size() != 5)
return false;
if (attn_mask->shape[0] != 1 || attn_mask->shape[2] != 1)
return false;
return true;
}
void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
fuse_multiheadattention_pass::write(op, captured_params, captured_attrs);
int num_heads = captured_params.at("shape").ai[captured_params.at("shape").ai.size() - 2];
Operand* attn_mask = op->inputs[1];
Operator* op_attr = attn_mask->producer;
// @mask2=(1,64,1,49,49)f32
Attribute mask2;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 5) == "op_9.")
mask2 = x.second;
}
int batch = op->inputs[0]->shape[0];
// hack attn_mask shape
attn_mask->shape = std::vector<int>{batch * attn_mask->shape[1], attn_mask->shape[2], attn_mask->shape[3]};
const std::string key = op_attr->attrs.begin()->first;
op_attr->attrs[key].shape = attn_mask->shape;
// hack attn_mask value
std::vector<char>& data = op_attr->attrs[key].data;
size_t len = data.size();
data.resize(len * batch);
for (int i = 1; i < batch; i++)
{
memcpy(&data[len * i], &data[0], len);
}
// add mask2
{
auto maskdata = op_attr->attrs[key].get_float32_data();
const int ls = mask2.shape[3] * mask2.shape[4];
for (int i = 0; i < batch; i++)
{
for (int n = 0; n < num_heads; n++)
{
float* p = (float*)maskdata.data() + ls * (i * num_heads + n);
const float* p2 = ((float*)mask2.data.data()) + ls * i;
for (int k = 0; k < ls; k++)
{
p[k] += p2[k];
}
}
}
op_attr->attrs[key].set_float32_data(maskdata);
}
}
};
void fuse_multiheadattention(Graph& graph)
......@@ -1318,6 +1511,8 @@ void fuse_multiheadattention(Graph& graph)
fuse_multiheadattention_pass_14 m;
fuse_multiheadattention_pass_15 n;
fuse_multiheadattention_pass_16 o;
fuse_multiheadattention_pass_17 p;
fuse_multiheadattention_pass_18 q;
int opindex = 0;
pnnx_graph_rewrite(graph, &a, opindex);
......@@ -1340,6 +1535,8 @@ void fuse_multiheadattention(Graph& graph)
pnnx_graph_rewrite(graph, &m, opindex);
pnnx_graph_rewrite(graph, &n, opindex);
pnnx_graph_rewrite(graph, &o, opindex);
pnnx_graph_rewrite(graph, &p, opindex);
pnnx_graph_rewrite(graph, &q, opindex);
#endif
}
......
......@@ -54,7 +54,7 @@ pnnx.Output output 1 0 out
op->params["0"] = weight.shape[1];
op->params["1"] = weight.shape[0];
op->params["2"] = 0;
op->params["3"] = (int)(weight.data.size() / sizeof(float));
op->params["3"] = weight.elemcount();
op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
......
......@@ -14,61 +14,12 @@
#include "convert_half_to_float.h"
#include <string.h>
namespace pnnx {
namespace ncnn {
static float float16_to_float32(unsigned short value)
{
// 1 : 5 : 10
unsigned short sign = (value & 0x8000) >> 15;
unsigned short exponent = (value & 0x7c00) >> 10;
unsigned short significand = value & 0x03FF;
// NCNN_LOGE("%d %d %d", sign, exponent, significand);
// 1 : 8 : 23
union
{
unsigned int u;
float f;
} tmp;
if (exponent == 0)
{
if (significand == 0)
{
// zero
tmp.u = (sign << 31);
}
else
{
// denormal
exponent = 0;
// find non-zero bit
while ((significand & 0x200) == 0)
{
significand <<= 1;
exponent++;
}
significand <<= 1;
significand &= 0x3FF;
tmp.u = (sign << 31) | ((-exponent + (-15 + 127)) << 23) | (significand << 13);
}
}
else if (exponent == 0x1F)
{
// infinity or NaN
tmp.u = (sign << 31) | (0xFF << 23) | (significand << 13);
}
else
{
// normalized
tmp.u = (sign << 31) | ((exponent + (-15 + 127)) << 23) | (significand << 13);
}
return tmp.f;
}
void convert_half_to_float(Graph& graph)
{
for (Operator* op : graph.ops)
......@@ -89,15 +40,10 @@ void convert_half_to_float(Graph& graph)
Attribute attr_new;
attr_new.type = 1;
attr_new.shape = attr.shape;
attr_new.data.resize(attr.data.size() * 2);
attr_new.data.resize(attr.elemcount() * 4);
const unsigned short* p = (const unsigned short*)attr.data.data();
float* outp = (float*)attr_new.data.data();
int len = attr_new.data.size() / 4;
for (int i = 0; i < len; i++)
{
outp[i] = float16_to_float32(p[i]);
}
auto p = attr.get_float32_data();
memcpy((void*)attr_new.data.data(), (const void*)p.data(), attr_new.data.size());
op->attrs[x.first] = attr_new;
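Since the source attribute in this pass is always f16 (2 bytes per element), the old attr.data.size() * 2 and the new attr.elemcount() * 4 allocate the same number of bytes; the new expression simply states the f32 output size directly.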
......
......@@ -65,7 +65,7 @@ pnnx.Output output 1 0 out
op->params["4"] = captured_params.at("padding").ai[0];
}
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
......@@ -122,7 +122,7 @@ pnnx.Output output 1 0 out
op->params["4"] = captured_params.at("padding").ai[0];
}
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
op->params["7"] = captured_params.at("groups");
op->attrs["0"] = Attribute();
......
......@@ -69,7 +69,7 @@ pnnx.Output output 1 0 out
op->params["14"] = captured_params.at("padding").ai[0];
}
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
......@@ -130,7 +130,7 @@ pnnx.Output output 1 0 out
op->params["14"] = captured_params.at("padding").ai[0];
}
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
op->params["7"] = captured_params.at("groups");
op->attrs["0"] = Attribute();
......
......@@ -73,7 +73,7 @@ pnnx.Output output 1 0 out
op->params["24"] = captured_params.at("padding").ai[0];
}
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
......@@ -138,7 +138,7 @@ pnnx.Output output 1 0 out
op->params["24"] = captured_params.at("padding").ai[0];
}
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
op->params["7"] = captured_params.at("groups");
op->attrs["0"] = Attribute();
......
......@@ -50,7 +50,7 @@ pnnx.Output output 1 0 out
op->params["4"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[0];
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
// transpose inch-outch-kw to outch-inch-kw
const int inch = captured_params.at("in_channels").i;
......@@ -58,7 +58,7 @@ pnnx.Output output 1 0 out
const int kw = captured_params.at("kernel_size").ai[0];
std::vector<float> new_weight;
{
const float* w = (const float*)captured_attrs.at("op_0.weight").data.data();
auto w = captured_attrs.at("op_0.weight").get_float32_data();
new_weight.resize(outch * inch * kw);
float* w2 = (float*)new_weight.data();
......@@ -116,7 +116,7 @@ pnnx.Output output 1 0 out
op->params["4"] = captured_params.at("padding").ai[0];
op->params["18"] = captured_params.at("output_padding").ai[0];
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
op->params["7"] = captured_params.at("groups");
// transpose group-inch/group-outch/group-kw to group-outch/group-inch/group-kw
......@@ -126,7 +126,7 @@ pnnx.Output output 1 0 out
const int kw = captured_params.at("kernel_size").ai[0];
std::vector<float> new_weight;
{
const float* w = (const float*)captured_attrs.at("op_0.weight").data.data();
auto w = captured_attrs.at("op_0.weight").get_float32_data();
new_weight.resize(outch / groups * inch * kw);
float* w2 = (float*)new_weight.data();
......@@ -137,7 +137,7 @@ pnnx.Output output 1 0 out
{
// reorder weight from inch-outch to outch-inch
float* wg2 = w2 + g * outch_g * inch_g * kw;
const float* wg = w + g * inch_g * outch_g * kw;
const float* wg = (const float*)w.data() + g * inch_g * outch_g * kw;
for (int i = 0; i < outch_g; i++)
{
for (int j = 0; j < inch_g; j++)
......
......@@ -55,7 +55,7 @@ pnnx.Output output 1 0 out
op->params["18"] = captured_params.at("output_padding").ai[1];
op->params["19"] = captured_params.at("output_padding").ai[0];
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
// transpose inch-outch-kh-kw to outch-inch-kh-kw
const int inch = captured_params.at("in_channels").i;
......@@ -64,7 +64,7 @@ pnnx.Output output 1 0 out
const int kw = captured_params.at("kernel_size").ai[1];
std::vector<float> new_weight;
{
const float* w = (const float*)captured_attrs.at("op_0.weight").data.data();
auto w = captured_attrs.at("op_0.weight").get_float32_data();
new_weight.resize(outch * inch * kh * kw);
float* w2 = (float*)new_weight.data();
......@@ -128,7 +128,7 @@ pnnx.Output output 1 0 out
op->params["18"] = captured_params.at("output_padding").ai[1];
op->params["19"] = captured_params.at("output_padding").ai[0];
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
op->params["7"] = captured_params.at("groups");
// transpose group-inch/group-outch/group-kh-kw to group-outch/group-inch/group-kh-kw
......@@ -139,7 +139,7 @@ pnnx.Output output 1 0 out
const int kw = captured_params.at("kernel_size").ai[1];
std::vector<float> new_weight;
{
const float* w = (const float*)captured_attrs.at("op_0.weight").data.data();
auto w = captured_attrs.at("op_0.weight").get_float32_data();
new_weight.resize(outch / groups * inch * kh * kw);
float* w2 = (float*)new_weight.data();
......@@ -151,7 +151,7 @@ pnnx.Output output 1 0 out
{
// reorder weight from inch-outch to outch-inch
float* wg2 = w2 + g * outch_g * inch_g * maxk;
const float* wg = w + g * inch_g * outch_g * maxk;
const float* wg = (const float*)w.data() + g * inch_g * outch_g * maxk;
for (int i = 0; i < outch_g; i++)
{
for (int j = 0; j < inch_g; j++)
......
......@@ -60,7 +60,7 @@ pnnx.Output output 1 0 out
op->params["19"] = captured_params.at("output_padding").ai[1];
op->params["20"] = captured_params.at("output_padding").ai[0];
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
// transpose inch-outch-kd-kh-kw to outch-inch-kd-kh-kw
const int inch = captured_params.at("in_channels").i;
......@@ -70,7 +70,7 @@ pnnx.Output output 1 0 out
const int kw = captured_params.at("kernel_size").ai[2];
std::vector<float> new_weight;
{
const float* w = (const float*)captured_attrs.at("op_0.weight").data.data();
auto w = captured_attrs.at("op_0.weight").get_float32_data();
new_weight.resize(outch * inch * kd * kh * kw);
float* w2 = (float*)new_weight.data();
......@@ -139,7 +139,7 @@ pnnx.Output output 1 0 out
op->params["19"] = captured_params.at("output_padding").ai[1];
op->params["20"] = captured_params.at("output_padding").ai[0];
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
op->params["7"] = captured_params.at("groups");
// transpose group-inch/group-outch/group-kd-kh-kw to group-outch/group-inch/group-kd-kh-kw
......@@ -151,7 +151,7 @@ pnnx.Output output 1 0 out
const int kw = captured_params.at("kernel_size").ai[2];
std::vector<float> new_weight;
{
const float* w = (const float*)captured_attrs.at("op_0.weight").data.data();
auto w = captured_attrs.at("op_0.weight").get_float32_data();
new_weight.resize(outch / groups * inch * kd * kh * kw);
float* w2 = (float*)new_weight.data();
......@@ -163,7 +163,7 @@ pnnx.Output output 1 0 out
{
// reorder weight from inch-outch to outch-inch
float* wg2 = w2 + g * outch_g * inch_g * maxk;
const float* wg = w + g * inch_g * outch_g * maxk;
const float* wg = (const float*)w.data() + g * inch_g * outch_g * maxk;
for (int i = 0; i < outch_g; i++)
{
for (int j = 0; j < inch_g; j++)
......
......@@ -46,7 +46,7 @@ pnnx.Output output 1 0 out
op->params["0"] = captured_params.at("embedding_dim");
op->params["1"] = captured_params.at("num_embeddings");
op->params["2"] = 0;
op->params["3"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["3"] = captured_attrs.at("op_0.weight").elemcount();
op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
......
......@@ -73,8 +73,8 @@ pnnx.Output output 2 0 out out_hidden
// reduce bias_ih and bias_hh
std::vector<float> new_bias;
{
const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0").data.data();
const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0").data.data();
auto bias_ih = captured_attrs.at("op_0.bias_ih_l0").get_float32_data();
auto bias_hh = captured_attrs.at("op_0.bias_hh_l0").get_float32_data();
new_bias.resize(4 * num_output);
float* bias = (float*)new_bias.data();
......@@ -82,16 +82,16 @@ pnnx.Output output 2 0 out out_hidden
{
bias[i] = bias_ih[i] + bias_hh[i];
}
memcpy(bias + num_output * 2, bias_ih + num_output * 2, num_output * sizeof(float));
memcpy(bias + num_output * 3, bias_hh + num_output * 2, num_output * sizeof(float));
memcpy(bias + num_output * 2, (const float*)bias_ih.data() + num_output * 2, num_output * sizeof(float));
memcpy(bias + num_output * 3, (const float*)bias_hh.data() + num_output * 2, num_output * sizeof(float));
}
if (bidirectional)
{
std::vector<float> new_bias_reverse;
{
const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0_reverse").data.data();
const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0_reverse").data.data();
auto bias_ih = captured_attrs.at("op_0.bias_ih_l0_reverse").get_float32_data();
auto bias_hh = captured_attrs.at("op_0.bias_hh_l0_reverse").get_float32_data();
new_bias_reverse.resize(4 * num_output);
float* bias = (float*)new_bias_reverse.data();
......@@ -99,8 +99,8 @@ pnnx.Output output 2 0 out out_hidden
{
bias[i] = bias_ih[i] + bias_hh[i];
}
memcpy(bias + num_output * 2, bias_ih + num_output * 2, num_output * sizeof(float));
memcpy(bias + num_output * 3, bias_hh + num_output * 2, num_output * sizeof(float));
memcpy(bias + num_output * 2, (const float*)bias_ih.data() + num_output * 2, num_output * sizeof(float));
memcpy(bias + num_output * 3, (const float*)bias_hh.data() + num_output * 2, num_output * sizeof(float));
}
op->attrs["3"] = Attribute({4, num_output}, new_bias) + Attribute({4, num_output}, new_bias_reverse);
......
......@@ -69,11 +69,11 @@ pnnx.Output output 3 0 out out_hidden out_cell
{
const int weight_data_size_g = hidden_size * input_size;
const float* weight_ih = (const float*)captured_attrs.at("op_0.weight_ih_l0").data.data();
const float* iptr = weight_ih;
const float* fptr = weight_ih + weight_data_size_g;
const float* gptr = weight_ih + weight_data_size_g * 2;
const float* optr = weight_ih + weight_data_size_g * 3;
auto weight_ih = captured_attrs.at("op_0.weight_ih_l0").get_float32_data();
const float* iptr = (const float*)weight_ih.data();
const float* fptr = (const float*)weight_ih.data() + weight_data_size_g;
const float* gptr = (const float*)weight_ih.data() + weight_data_size_g * 2;
const float* optr = (const float*)weight_ih.data() + weight_data_size_g * 3;
new_weight_ih.resize(4 * hidden_size * input_size);
float* weight = (float*)new_weight_ih.data();
......@@ -93,11 +93,11 @@ pnnx.Output output 3 0 out out_hidden out_cell
{
const int weight_data_size_g = hidden_size * input_size;
const float* weight_ih = (const float*)captured_attrs.at("op_0.weight_ih_l0_reverse").data.data();
const float* iptr = weight_ih;
const float* fptr = weight_ih + weight_data_size_g;
const float* gptr = weight_ih + weight_data_size_g * 2;
const float* optr = weight_ih + weight_data_size_g * 3;
auto weight_ih = captured_attrs.at("op_0.weight_ih_l0_reverse").get_float32_data();
const float* iptr = (const float*)weight_ih.data();
const float* fptr = (const float*)weight_ih.data() + weight_data_size_g;
const float* gptr = (const float*)weight_ih.data() + weight_data_size_g * 2;
const float* optr = (const float*)weight_ih.data() + weight_data_size_g * 3;
new_weight_ih_reverse.resize(4 * hidden_size * input_size);
float* weight = (float*)new_weight_ih_reverse.data();
......@@ -126,16 +126,16 @@ pnnx.Output output 3 0 out out_hidden out_cell
// reorder IFGO-hidden to IFOG-hidden
std::vector<float> new_bias;
{
const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0").data.data();
const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0").data.data();
const float* bias_ih_iptr = bias_ih;
const float* bias_ih_fptr = bias_ih + hidden_size;
const float* bias_ih_gptr = bias_ih + hidden_size * 2;
const float* bias_ih_optr = bias_ih + hidden_size * 3;
const float* bias_hh_iptr = bias_hh;
const float* bias_hh_fptr = bias_hh + hidden_size;
const float* bias_hh_gptr = bias_hh + hidden_size * 2;
const float* bias_hh_optr = bias_hh + hidden_size * 3;
auto bias_ih = captured_attrs.at("op_0.bias_ih_l0").get_float32_data();
auto bias_hh = captured_attrs.at("op_0.bias_hh_l0").get_float32_data();
const float* bias_ih_iptr = (const float*)bias_ih.data();
const float* bias_ih_fptr = (const float*)bias_ih.data() + hidden_size;
const float* bias_ih_gptr = (const float*)bias_ih.data() + hidden_size * 2;
const float* bias_ih_optr = (const float*)bias_ih.data() + hidden_size * 3;
const float* bias_hh_iptr = (const float*)bias_hh.data();
const float* bias_hh_fptr = (const float*)bias_hh.data() + hidden_size;
const float* bias_hh_gptr = (const float*)bias_hh.data() + hidden_size * 2;
const float* bias_hh_optr = (const float*)bias_hh.data() + hidden_size * 3;
new_bias.resize(4 * hidden_size);
float* bias = (float*)new_bias.data();
......@@ -165,16 +165,16 @@ pnnx.Output output 3 0 out out_hidden out_cell
{
std::vector<float> new_bias_reverse;
{
const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0_reverse").data.data();
const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0_reverse").data.data();
const float* bias_ih_iptr = bias_ih;
const float* bias_ih_fptr = bias_ih + hidden_size;
const float* bias_ih_gptr = bias_ih + hidden_size * 2;
const float* bias_ih_optr = bias_ih + hidden_size * 3;
const float* bias_hh_iptr = bias_hh;
const float* bias_hh_fptr = bias_hh + hidden_size;
const float* bias_hh_gptr = bias_hh + hidden_size * 2;
const float* bias_hh_optr = bias_hh + hidden_size * 3;
auto bias_ih = captured_attrs.at("op_0.bias_ih_l0_reverse").get_float32_data();
auto bias_hh = captured_attrs.at("op_0.bias_hh_l0_reverse").get_float32_data();
const float* bias_ih_iptr = (const float*)bias_ih.data();
const float* bias_ih_fptr = (const float*)bias_ih.data() + hidden_size;
const float* bias_ih_gptr = (const float*)bias_ih.data() + hidden_size * 2;
const float* bias_ih_optr = (const float*)bias_ih.data() + hidden_size * 3;
const float* bias_hh_iptr = (const float*)bias_hh.data();
const float* bias_hh_fptr = (const float*)bias_hh.data() + hidden_size;
const float* bias_hh_gptr = (const float*)bias_hh.data() + hidden_size * 2;
const float* bias_hh_optr = (const float*)bias_hh.data() + hidden_size * 3;
new_bias_reverse.resize(4 * hidden_size);
float* bias = (float*)new_bias_reverse.data();
......@@ -226,11 +226,11 @@ pnnx.Output output 3 0 out out_hidden out_cell
{
const int weight_data_size_g = hidden_size * proj_size;
const float* weight_hh = (const float*)captured_attrs.at("op_0.weight_hh_l0").data.data();
const float* iptr = weight_hh;
const float* fptr = weight_hh + weight_data_size_g;
const float* gptr = weight_hh + weight_data_size_g * 2;
const float* optr = weight_hh + weight_data_size_g * 3;
auto weight_hh = captured_attrs.at("op_0.weight_hh_l0").get_float32_data();
const float* iptr = (const float*)weight_hh.data();
const float* fptr = (const float*)weight_hh.data() + weight_data_size_g;
const float* gptr = (const float*)weight_hh.data() + weight_data_size_g * 2;
const float* optr = (const float*)weight_hh.data() + weight_data_size_g * 3;
new_weight_hh.resize(4 * hidden_size * proj_size);
float* weight = (float*)new_weight_hh.data();
......@@ -250,11 +250,11 @@ pnnx.Output output 3 0 out out_hidden out_cell
{
const int weight_data_size_g = hidden_size * proj_size;
const float* weight_hh = (const float*)captured_attrs.at("op_0.weight_hh_l0_reverse").data.data();
const float* iptr = weight_hh;
const float* fptr = weight_hh + weight_data_size_g;
const float* gptr = weight_hh + weight_data_size_g * 2;
const float* optr = weight_hh + weight_data_size_g * 3;
auto weight_hh = captured_attrs.at("op_0.weight_hh_l0_reverse").get_float32_data();
const float* iptr = (const float*)weight_hh.data();
const float* fptr = (const float*)weight_hh.data() + weight_data_size_g;
const float* gptr = (const float*)weight_hh.data() + weight_data_size_g * 2;
const float* optr = (const float*)weight_hh.data() + weight_data_size_g * 3;
new_weight_hh_reverse.resize(4 * hidden_size * proj_size);
float* weight = (float*)new_weight_hh_reverse.data();
......
......@@ -45,7 +45,7 @@ pnnx.Output output 1 0 out
{
op->params["0"] = captured_params.at("out_features");
op->params["1"] = captured_params.at("bias").b ? 1 : 0;
op->params["2"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["2"] = captured_attrs.at("op_0.weight").elemcount();
op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
......
......@@ -67,9 +67,12 @@ pnnx.Output output 1 0 out
std::vector<float> v_bias(embed_dim);
{
// qkv - embed_dim - embed_dim
const float* wptr = (const float*)captured_attrs.at("op_0.in_proj_weight").data.data();
auto w = captured_attrs.at("op_0.in_proj_weight").get_float32_data();
// qkv - embed_dim
const float* bptr = (const float*)captured_attrs.at("op_0.in_proj_bias").data.data();
auto b = captured_attrs.at("op_0.in_proj_bias").get_float32_data();
const float* wptr = (const float*)w.data();
const float* bptr = (const float*)b.data();
{
memcpy(q_weight.data(), wptr, embed_dim * embed_dim * sizeof(float));
......@@ -235,7 +238,9 @@ pnnx.Output output 1 0 out
std::vector<float> v_bias(embed_dim);
{
// qkv - embed_dim
const float* bptr = (const float*)captured_attrs.at("op_0.in_proj_bias").data.data();
auto b = captured_attrs.at("op_0.in_proj_bias").get_float32_data();
const float* bptr = (const float*)b.data();
{
memcpy(q_bias.data(), bptr, embed_dim * sizeof(float));
......@@ -264,7 +269,9 @@ pnnx.Output output 1 0 out
std::vector<float> v_weight(embed_dim * vdim);
{
// qkv - embed_dim - embed_dim
const float* wptr = (const float*)captured_attrs.at("op_0.in_proj_weight").data.data();
auto w = captured_attrs.at("op_0.in_proj_weight").get_float32_data();
const float* wptr = (const float*)w.data();
{
memcpy(q_weight.data(), wptr, embed_dim * embed_dim * sizeof(float));
......
......@@ -75,8 +75,8 @@ pnnx.Output output 2 0 out out_hidden
// reduce bias_ih and bias_hh
std::vector<float> new_bias;
{
const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0").data.data();
const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0").data.data();
auto bias_ih = captured_attrs.at("op_0.bias_ih_l0").get_float32_data();
auto bias_hh = captured_attrs.at("op_0.bias_hh_l0").get_float32_data();
new_bias.resize(num_output);
float* bias = (float*)new_bias.data();
......@@ -90,8 +90,8 @@ pnnx.Output output 2 0 out out_hidden
{
std::vector<float> new_bias_reverse;
{
const float* bias_ih = (const float*)captured_attrs.at("op_0.bias_ih_l0_reverse").data.data();
const float* bias_hh = (const float*)captured_attrs.at("op_0.bias_hh_l0_reverse").data.data();
auto bias_ih = captured_attrs.at("op_0.bias_ih_l0_reverse").get_float32_data();
auto bias_hh = captured_attrs.at("op_0.bias_hh_l0_reverse").get_float32_data();
new_bias_reverse.resize(num_output);
float* bias = (float*)new_bias_reverse.data();
......
......@@ -105,7 +105,7 @@ pnnx.Output output 1 0 out
const int outch = weight.shape[1];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();
auto w = weight.get_float32_data();
new_weight.resize(outch * inch);
float* w2 = (float*)new_weight.data();
......@@ -122,7 +122,7 @@ pnnx.Output output 1 0 out
op->params["0"] = outch;
op->params["1"] = 1;
op->params["2"] = (int)(weight.data.size() / sizeof(float));
op->params["2"] = weight.elemcount();
op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
......
......@@ -54,7 +54,7 @@ pnnx.Output output 1 0 out
op->params["4"] = captured_params.at("padding").ai[1];
op->params["14"] = captured_params.at("padding").ai[0];
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
......@@ -101,7 +101,7 @@ pnnx.Output output 1 0 out
op->params["4"] = captured_params.at("padding").ai[1];
op->params["14"] = captured_params.at("padding").ai[0];
op->params["5"] = captured_params.at("bias").b ? 1 : 0;
op->params["6"] = (int)(captured_attrs.at("op_0.weight").data.size() / sizeof(float));
op->params["6"] = captured_attrs.at("op_0.weight").elemcount();
op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
......
......@@ -20,6 +20,8 @@
#include <fstream>
#include <iostream>
#include "utils.h"
namespace pnnx {
// from cxxabi bridge
......@@ -33,60 +35,6 @@ extern const Attribute& get_operator_attr(const Operator* op, const char* key);
extern const char* get_param_s(const Parameter& p);
extern std::vector<const char*> get_param_as(const Parameter& p);
static unsigned short float32_to_float16(float value)
{
// 1 : 8 : 23
union
{
unsigned int u;
float f;
} tmp;
tmp.f = value;
// 1 : 8 : 23
unsigned short sign = (tmp.u & 0x80000000) >> 31;
unsigned short exponent = (tmp.u & 0x7F800000) >> 23;
unsigned int significand = tmp.u & 0x7FFFFF;
// NCNN_LOGE("%d %d %d", sign, exponent, significand);
// 1 : 5 : 10
unsigned short fp16;
if (exponent == 0)
{
// zero or denormal, always underflow
fp16 = (sign << 15) | (0x00 << 10) | 0x00;
}
else if (exponent == 0xFF)
{
// infinity or NaN
fp16 = (sign << 15) | (0x1F << 10) | (significand ? 0x200 : 0x00);
}
else
{
// normalized
short newexp = exponent + (-127 + 15);
if (newexp >= 31)
{
// overflow, return infinity
fp16 = (sign << 15) | (0x1F << 10) | 0x00;
}
else if (newexp <= 0)
{
// Some normal fp32 cannot be expressed as normal fp16
fp16 = (sign << 15) | (0x00 << 10) | 0x00;
}
else
{
// normal fp16
fp16 = (sign << 15) | (newexp << 10) | (significand >> 13);
}
}
return fp16;
}
int save_onnx(const Graph& g, const char* onnxpath, int fp16)
{
onnx::ModelProto model;
......
......@@ -27,4 +27,109 @@ const torch::jit::Node* find_node_by_kind(const std::shared_ptr<torch::jit::Grap
return 0;
}
unsigned short float32_to_float16(float value)
{
// 1 : 8 : 23
union
{
unsigned int u;
float f;
} tmp;
tmp.f = value;
// 1 : 8 : 23
unsigned short sign = (tmp.u & 0x80000000) >> 31;
unsigned short exponent = (tmp.u & 0x7F800000) >> 23;
unsigned int significand = tmp.u & 0x7FFFFF;
// NCNN_LOGE("%d %d %d", sign, exponent, significand);
// 1 : 5 : 10
unsigned short fp16;
if (exponent == 0)
{
// zero or denormal, always underflow
fp16 = (sign << 15) | (0x00 << 10) | 0x00;
}
else if (exponent == 0xFF)
{
// infinity or NaN
fp16 = (sign << 15) | (0x1F << 10) | (significand ? 0x200 : 0x00);
}
else
{
// normalized
short newexp = exponent + (-127 + 15);
if (newexp >= 31)
{
// overflow, return infinity
fp16 = (sign << 15) | (0x1F << 10) | 0x00;
}
else if (newexp <= 0)
{
// Some normal fp32 cannot be expressed as normal fp16
fp16 = (sign << 15) | (0x00 << 10) | 0x00;
}
else
{
// normal fp16
fp16 = (sign << 15) | (newexp << 10) | (significand >> 13);
}
}
return fp16;
}
float float16_to_float32(unsigned short value)
{
// 1 : 5 : 10
unsigned short sign = (value & 0x8000) >> 15;
unsigned short exponent = (value & 0x7c00) >> 10;
unsigned short significand = value & 0x03FF;
// NCNN_LOGE("%d %d %d", sign, exponent, significand);
// 1 : 8 : 23
union
{
unsigned int u;
float f;
} tmp;
if (exponent == 0)
{
if (significand == 0)
{
// zero
tmp.u = (sign << 31);
}
else
{
// denormal
exponent = 0;
// find non-zero bit
while ((significand & 0x200) == 0)
{
significand <<= 1;
exponent++;
}
significand <<= 1;
significand &= 0x3FF;
tmp.u = (sign << 31) | ((-exponent + (-15 + 127)) << 23) | (significand << 13);
}
}
else if (exponent == 0x1F)
{
// infinity or NaN
tmp.u = (sign << 31) | (0xFF << 23) | (significand << 13);
}
else
{
// normalized
tmp.u = (sign << 31) | ((exponent + (-15 + 127)) << 23) | (significand << 13);
}
return tmp.f;
}
} // namespace pnnx
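These two routines are now the single shared implementation used by the passes above; a minimal round-trip check (a standalone sketch, assuming it is compiled and linked together with the pnnx sources so that utils.h resolves):

#include <cassert>
#include <cmath>
#include <cstdio>
#include "utils.h"

int main()
{
    // values below the fp16 normal range are flushed to zero by
    // float32_to_float16, so stick to normal-range inputs here
    const float values[] = {0.f, 1.f, 1.5f, -0.33f, 65504.f};
    for (float v : values)
    {
        unsigned short h = pnnx::float32_to_float16(v);
        float back = pnnx::float16_to_float32(h);
        // fp16 keeps ~11 significand bits, so allow a small relative tolerance
        assert(std::fabs(back - v) <= std::fabs(v) * 1e-3f + 1e-6f);
        printf("%f -> 0x%04x -> %f\n", v, (unsigned)h, back);
    }
    return 0;
}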
......@@ -22,6 +22,10 @@ namespace pnnx {
const torch::jit::Node* find_node_by_kind(const std::shared_ptr<torch::jit::Graph>& graph, const std::string& kind);
unsigned short float32_to_float16(float value);
float float16_to_float32(unsigned short value);
} // namespace pnnx
#endif // PNNX_UTILS_H