未验证 提交 246e71c5 编写于 作者: N nihui 提交者: GitHub

implement atan2 (#4516)

上级 92e75105
......@@ -163,6 +163,8 @@ Operation type:
- 7 = RSUB
- 8 = RDIV
- 9 = RPOW
- 10 = ATAN2
- 11 = RATAN2
# BNLL
```
......
......@@ -555,6 +555,8 @@ MAKE_FUNCTION(binary_op_rdiv, y / x, vdivq_f32(y, x))
MAKE_FUNCTION(binary_op_rdiv, y / x, div_ps(y, x))
#endif
MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x))
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x))
// *INDENT-ON*
// clang-format on
......@@ -576,6 +578,8 @@ static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Op
if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -592,9 +596,11 @@ static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_typ
if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast<binary_op_max>(a, b, c, opt);
if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast<binary_op_min>(a, b, c, opt);
if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast<binary_op_pow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast<binary_op_sub>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast<binary_op_div>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast<binary_op_pow>(b, a, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_no_broadcast<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_no_broadcast<binary_op_atan2>(b, a, c, opt);
// should never reach here
return 0;
......@@ -629,6 +635,8 @@ static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, int op_
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner<binary_op_rsub>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner<binary_op_rdiv>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner<binary_op_rpow>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_inner<binary_op_atan2>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_inner<binary_op_ratan2>(a, b2, c, opt);
// should never reach here
return 0;
......@@ -648,6 +656,8 @@ static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, int op_
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_outer<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_outer<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -667,6 +677,8 @@ static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, int op_typ
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_20<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_20<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -677,9 +689,11 @@ static int get_reverse_op_type(int op_type)
if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB;
if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV;
if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW;
if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2;
if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB;
if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV;
if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW;
if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2;
return op_type;
}
......@@ -775,6 +789,8 @@ int BinaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_RSUB) return binary_op_scalar_inplace<binary_op_rsub>(bottom_top_blob, b, opt);
if (op_type == Operation_RDIV) return binary_op_scalar_inplace<binary_op_rdiv>(bottom_top_blob, b, opt);
if (op_type == Operation_RPOW) return binary_op_scalar_inplace<binary_op_rpow>(bottom_top_blob, b, opt);
if (op_type == Operation_ATAN2) return binary_op_scalar_inplace<binary_op_atan2>(bottom_top_blob, b, opt);
if (op_type == Operation_RATAN2) return binary_op_scalar_inplace<binary_op_ratan2>(bottom_top_blob, b, opt);
return 0;
}
......@@ -1263,6 +1279,8 @@ static int binary_op_scalar_bf16s(const Mat& a, float b, Mat& c, int op_type, co
if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar_bf16s<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar_bf16s<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar_bf16s<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar_bf16s<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar_bf16s<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -1279,9 +1297,11 @@ static int binary_op_no_broadcast_bf16s(const Mat& a, const Mat& b, Mat& c, int
if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast_bf16s<binary_op_max>(a, b, c, opt);
if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast_bf16s<binary_op_min>(a, b, c, opt);
if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast_bf16s<binary_op_pow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast_bf16s<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast_bf16s<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast_bf16s<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast_bf16s<binary_op_sub>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast_bf16s<binary_op_div>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast_bf16s<binary_op_pow>(b, a, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_no_broadcast_bf16s<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_no_broadcast_bf16s<binary_op_atan2>(b, a, c, opt);
// should never reach here
return 0;
......@@ -1316,6 +1336,8 @@ static int binary_op_broadcast_inner_bf16s(const Mat& a, const Mat& b, Mat& c, i
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner_bf16s<binary_op_rsub>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner_bf16s<binary_op_rdiv>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner_bf16s<binary_op_rpow>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_inner_bf16s<binary_op_atan2>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_inner_bf16s<binary_op_ratan2>(a, b2, c, opt);
// should never reach here
return 0;
......@@ -1335,6 +1357,8 @@ static int binary_op_broadcast_outer_bf16s(const Mat& a, const Mat& b, Mat& c, i
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer_bf16s<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer_bf16s<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer_bf16s<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_outer_bf16s<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_outer_bf16s<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -1354,6 +1378,8 @@ static int binary_op_broadcast_20_bf16s(const Mat& a, const Mat& b, Mat& c, int
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20_bf16s<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20_bf16s<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20_bf16s<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_20_bf16s<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_20_bf16s<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -1427,6 +1453,8 @@ int BinaryOp_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt)
if (op_type == Operation_RSUB) return binary_op_scalar_inplace_bf16s<binary_op_rsub>(bottom_top_blob, b, opt);
if (op_type == Operation_RDIV) return binary_op_scalar_inplace_bf16s<binary_op_rdiv>(bottom_top_blob, b, opt);
if (op_type == Operation_RPOW) return binary_op_scalar_inplace_bf16s<binary_op_rpow>(bottom_top_blob, b, opt);
if (op_type == Operation_ATAN2) return binary_op_scalar_inplace_bf16s<binary_op_atan2>(bottom_top_blob, b, opt);
if (op_type == Operation_RATAN2) return binary_op_scalar_inplace_bf16s<binary_op_ratan2>(bottom_top_blob, b, opt);
return 0;
}
......
......@@ -598,6 +598,8 @@ MAKE_FUNCTION(binary_op_pow_fp16s, (__fp16)pow(x, y), vcvt_f16_f32(pow_ps(vcvt_f
MAKE_FUNCTION(binary_op_rsub_fp16s, y - x, vsub_f16(y, x), vsubq_f16(y, x))
MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vdiv_f16(y, x), vdivq_f16(y, x))
MAKE_FUNCTION(binary_op_rpow_fp16s, (__fp16)pow(y, x), vcvt_f16_f32(pow_ps(vcvt_f32_f16(y), vcvt_f32_f16(x))), vcombine_f16(vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_low_f16(y)), vcvt_f32_f16(vget_low_f16(x)))), vcvt_f16_f32(pow_ps(vcvt_f32_f16(vget_high_f16(y)), vcvt_f32_f16(vget_high_f16(x))))))
MAKE_FUNCTION(binary_op_atan2_fp16s, (__fp16)atan2(x, y), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(x), vcvt_f32_f16(y))), vcombine_f16(vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_low_f16(y)))), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_high_f16(x)), vcvt_f32_f16(vget_high_f16(y))))))
MAKE_FUNCTION(binary_op_ratan2_fp16s, (__fp16)atan2(y, x), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(y), vcvt_f32_f16(x))), vcombine_f16(vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_low_f16(y)), vcvt_f32_f16(vget_low_f16(x)))), vcvt_f16_f32(atan2_ps(vcvt_f32_f16(vget_high_f16(y)), vcvt_f32_f16(vget_high_f16(x))))))
// *INDENT-ON*
// clang-format on
......@@ -619,6 +621,8 @@ static int binary_op_scalar_fp16s(const Mat& a, float b, Mat& c, int op_type, co
if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar_fp16s<binary_op_rsub_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar_fp16s<binary_op_rdiv_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar_fp16s<binary_op_rpow_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar_fp16s<binary_op_atan2_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar_fp16s<binary_op_ratan2_fp16s>(a, b, c, opt);
// should never reach here
return 0;
......@@ -635,9 +639,11 @@ static int binary_op_no_broadcast_fp16s(const Mat& a, const Mat& b, Mat& c, int
if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast_fp16s<binary_op_max_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast_fp16s<binary_op_min_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast_fp16s<binary_op_pow_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast_fp16s<binary_op_rsub_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast_fp16s<binary_op_rdiv_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast_fp16s<binary_op_rpow_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast_fp16s<binary_op_sub_fp16s>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast_fp16s<binary_op_div_fp16s>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast_fp16s<binary_op_pow_fp16s>(b, a, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_no_broadcast_fp16s<binary_op_atan2_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_no_broadcast_fp16s<binary_op_atan2_fp16s>(b, a, c, opt);
// should never reach here
return 0;
......@@ -672,6 +678,8 @@ static int binary_op_broadcast_inner_fp16s(const Mat& a, const Mat& b, Mat& c, i
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner_fp16s<binary_op_rsub_fp16s>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner_fp16s<binary_op_rdiv_fp16s>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner_fp16s<binary_op_rpow_fp16s>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_inner_fp16s<binary_op_atan2_fp16s>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_inner_fp16s<binary_op_ratan2_fp16s>(a, b2, c, opt);
// should never reach here
return 0;
......@@ -691,6 +699,8 @@ static int binary_op_broadcast_outer_fp16s(const Mat& a, const Mat& b, Mat& c, i
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer_fp16s<binary_op_rsub_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer_fp16s<binary_op_rdiv_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer_fp16s<binary_op_rpow_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_outer_fp16s<binary_op_atan2_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_outer_fp16s<binary_op_ratan2_fp16s>(a, b, c, opt);
// should never reach here
return 0;
......@@ -710,6 +720,8 @@ static int binary_op_broadcast_20_fp16s(const Mat& a, const Mat& b, Mat& c, int
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20_fp16s<binary_op_rsub_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20_fp16s<binary_op_rdiv_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20_fp16s<binary_op_rpow_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_20_fp16s<binary_op_atan2_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_20_fp16s<binary_op_ratan2_fp16s>(a, b, c, opt);
// should never reach here
return 0;
......@@ -720,9 +732,11 @@ static int get_reverse_op_type(int op_type)
if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB;
if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV;
if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW;
if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2;
if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB;
if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV;
if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW;
if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2;
return op_type;
}
......@@ -794,6 +808,8 @@ int BinaryOp_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt)
if (op_type == Operation_RSUB) return binary_op_scalar_inplace_fp16s<binary_op_rsub_fp16s>(bottom_top_blob, (__fp16)b, opt);
if (op_type == Operation_RDIV) return binary_op_scalar_inplace_fp16s<binary_op_rdiv_fp16s>(bottom_top_blob, (__fp16)b, opt);
if (op_type == Operation_RPOW) return binary_op_scalar_inplace_fp16s<binary_op_rpow_fp16s>(bottom_top_blob, (__fp16)b, opt);
if (op_type == Operation_ATAN2) return binary_op_scalar_inplace_fp16s<binary_op_atan2_fp16s>(bottom_top_blob, (__fp16)b, opt);
if (op_type == Operation_RATAN2) return binary_op_scalar_inplace_fp16s<binary_op_ratan2_fp16s>(bottom_top_blob, (__fp16)b, opt);
return 0;
}
......
......@@ -412,5 +412,17 @@ static inline float32x4_t acos_ps(float32x4_t x)
return yacos;
}
static inline float32x4_t atan2_ps(float32x4_t a, float32x4_t b)
{
//TODO neon optimize
float tmpx[4];
float tmpy[4];
vst1q_f32(tmpx, a);
vst1q_f32(tmpy, b);
for (int i = 0; i < 4; i++)
tmpx[i] = atan2(tmpx[i], tmpy[i]);
return vld1q_f32(tmpx);
}
#include "neon_mathfun_tanh.h"
#endif // NEON_MATHFUN_H
......@@ -412,6 +412,22 @@ struct binary_op_rpow
}
};
struct binary_op_atan2
{
float operator()(const float& x, const float& y) const
{
return (float)atan2(x, y);
}
};
struct binary_op_ratan2
{
float operator()(const float& x, const float& y) const
{
return (float)atan2(y, x);
}
};
static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Option& opt)
{
if (op_type == BinaryOp::Operation_ADD) return binary_op_scalar<binary_op_add>(a, b, c, opt);
......@@ -424,6 +440,8 @@ static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Op
if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -438,9 +456,11 @@ static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_typ
if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast<binary_op_max>(a, b, c, opt);
if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast<binary_op_min>(a, b, c, opt);
if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast<binary_op_pow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast<binary_op_sub>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast<binary_op_div>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast<binary_op_pow>(b, a, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_no_broadcast<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_no_broadcast<binary_op_atan2>(b, a, c, opt);
// should never reach here
return 0;
......@@ -473,6 +493,8 @@ static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, int op_
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner<binary_op_rsub>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner<binary_op_rdiv>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner<binary_op_rpow>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_inner<binary_op_atan2>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_inner<binary_op_ratan2>(a, b2, c, opt);
// should never reach here
return 0;
......@@ -490,6 +512,8 @@ static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, int op_
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_outer<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_outer<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -507,6 +531,8 @@ static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, int op_typ
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_20<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_20<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -517,9 +543,11 @@ static int get_reverse_op_type(int op_type)
if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB;
if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV;
if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW;
if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2;
if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB;
if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV;
if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW;
if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2;
return op_type;
}
......@@ -594,6 +622,8 @@ int BinaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_RSUB) return binary_op_scalar_inplace<binary_op_rsub>(bottom_top_blob, b, opt);
if (op_type == Operation_RDIV) return binary_op_scalar_inplace<binary_op_rdiv>(bottom_top_blob, b, opt);
if (op_type == Operation_RPOW) return binary_op_scalar_inplace<binary_op_rpow>(bottom_top_blob, b, opt);
if (op_type == Operation_ATAN2) return binary_op_scalar_inplace<binary_op_atan2>(bottom_top_blob, b, opt);
if (op_type == Operation_RATAN2) return binary_op_scalar_inplace<binary_op_ratan2>(bottom_top_blob, b, opt);
// should nerver reach here
return 0;
......
......@@ -43,7 +43,9 @@ public:
Operation_POW = 6,
Operation_RSUB = 7,
Operation_RDIV = 8,
Operation_RPOW = 9
Operation_RPOW = 9,
Operation_ATAN2 = 10,
Operation_RATAN2 = 11
};
public:
......
......@@ -554,6 +554,8 @@ MAKE_FUNCTION(binary_op_pow, (float)pow(x, y), pow_ps(x, y))
MAKE_FUNCTION(binary_op_rsub, y - x, __lsx_vfsub_s(y, x))
MAKE_FUNCTION(binary_op_rdiv, y / x, __lsx_vfdiv_s(y, x))
MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x))
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x))
// *INDENT-ON*
// clang-format on
......@@ -575,6 +577,8 @@ static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Op
if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -591,9 +595,11 @@ static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_typ
if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast<binary_op_max>(a, b, c, opt);
if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast<binary_op_min>(a, b, c, opt);
if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast<binary_op_pow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast<binary_op_sub>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast<binary_op_div>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast<binary_op_pow>(b, a, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_no_broadcast<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_no_broadcast<binary_op_atan2>(b, a, c, opt);
// should never reach here
return 0;
......@@ -628,6 +634,8 @@ static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, int op_
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner<binary_op_rsub>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner<binary_op_rdiv>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner<binary_op_rpow>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_inner<binary_op_atan2>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_inner<binary_op_ratan2>(a, b2, c, opt);
// should never reach here
return 0;
......@@ -647,6 +655,8 @@ static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, int op_
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_outer<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_outer<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -666,6 +676,8 @@ static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, int op_typ
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_20<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_20<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -676,9 +688,11 @@ static int get_reverse_op_type(int op_type)
if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB;
if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV;
if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW;
if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2;
if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB;
if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV;
if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW;
if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2;
return op_type;
}
......@@ -750,6 +764,8 @@ int BinaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt)
if (op_type == Operation_RSUB) return binary_op_scalar_inplace<binary_op_rsub>(bottom_top_blob, b, opt);
if (op_type == Operation_RDIV) return binary_op_scalar_inplace<binary_op_rdiv>(bottom_top_blob, b, opt);
if (op_type == Operation_RPOW) return binary_op_scalar_inplace<binary_op_rpow>(bottom_top_blob, b, opt);
if (op_type == Operation_ATAN2) return binary_op_scalar_inplace<binary_op_atan2>(bottom_top_blob, b, opt);
if (op_type == Operation_RATAN2) return binary_op_scalar_inplace<binary_op_ratan2>(bottom_top_blob, b, opt);
return 0;
}
......
......@@ -255,4 +255,18 @@ static inline __m128 sigmoid_ps(__m128 _v)
return __lsx_vfdiv_s(_one, _v);
}
static inline __m128 atan2_ps(__m128 a, __m128 b)
{
//TODO lsx optimize
float tmpx[4];
float tmpy[4];
__lsx_vst(a, tmpx, 0);
__lsx_vst(b, tmpy, 0);
tmpx[0] = atan2(tmpx[0], tmpy[0]);
tmpx[1] = atan2(tmpx[1], tmpy[1]);
tmpx[2] = atan2(tmpx[2], tmpy[2]);
tmpx[3] = atan2(tmpx[3], tmpy[3]);
return (__m128)__lsx_vld(tmpx, 0);
}
#endif // LSX_MATHFUN_H
......@@ -550,6 +550,8 @@ MAKE_FUNCTION(binary_op_pow, (float)pow(x, y), pow_ps(x, y))
MAKE_FUNCTION(binary_op_rsub, y - x, __msa_fsub_w(y, x))
MAKE_FUNCTION(binary_op_rdiv, y / x, __msa_fdiv_w(y, x))
MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x))
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x))
// *INDENT-ON*
// clang-format on
......@@ -571,6 +573,8 @@ static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Op
if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -587,9 +591,11 @@ static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_typ
if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast<binary_op_max>(a, b, c, opt);
if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast<binary_op_min>(a, b, c, opt);
if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast<binary_op_pow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast<binary_op_sub>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast<binary_op_div>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast<binary_op_pow>(b, a, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_no_broadcast<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_no_broadcast<binary_op_atan2>(b, a, c, opt);
// should never reach here
return 0;
......@@ -624,6 +630,8 @@ static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, int op_
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner<binary_op_rsub>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner<binary_op_rdiv>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner<binary_op_rpow>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_inner<binary_op_atan2>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_inner<binary_op_ratan2>(a, b2, c, opt);
// should never reach here
return 0;
......@@ -643,6 +651,8 @@ static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, int op_
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_outer<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_outer<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -662,6 +672,8 @@ static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, int op_typ
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_20<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_20<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -672,9 +684,11 @@ static int get_reverse_op_type(int op_type)
if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB;
if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV;
if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW;
if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2;
if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB;
if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV;
if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW;
if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2;
return op_type;
}
......@@ -746,6 +760,8 @@ int BinaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons
if (op_type == Operation_RSUB) return binary_op_scalar_inplace<binary_op_rsub>(bottom_top_blob, b, opt);
if (op_type == Operation_RDIV) return binary_op_scalar_inplace<binary_op_rdiv>(bottom_top_blob, b, opt);
if (op_type == Operation_RPOW) return binary_op_scalar_inplace<binary_op_rpow>(bottom_top_blob, b, opt);
if (op_type == Operation_ATAN2) return binary_op_scalar_inplace<binary_op_atan2>(bottom_top_blob, b, opt);
if (op_type == Operation_RATAN2) return binary_op_scalar_inplace<binary_op_ratan2>(bottom_top_blob, b, opt);
return 0;
}
......
......@@ -253,4 +253,18 @@ static inline v4f32 sigmoid_ps(v4f32 _v)
return __msa_fdiv_w(_one, _v);
}
static inline v4f32 atan2_ps(v4f32 a, v4f32 b)
{
//TODO msa optimize
float tmpx[4];
float tmpy[4];
__msa_st_w((v4i32)a, tmpx, 0);
__msa_st_w((v4i32)b, tmpy, 0);
tmpx[0] = atan2(tmpx[0], tmpy[0]);
tmpx[1] = atan2(tmpx[1], tmpy[1]);
tmpx[2] = atan2(tmpx[2], tmpy[2]);
tmpx[3] = atan2(tmpx[3], tmpy[3]);
return (v4f32)__msa_ld_w(tmpx, 0);
}
#endif // MSA_MATHFUN_H
......@@ -579,6 +579,8 @@ MAKE_FUNCTION(binary_op_pow, (float)pow(x, y), pow_ps(x, y, vl), pow_ps(x, vfmv_
MAKE_FUNCTION(binary_op_rsub, y - x, vfsub_vv_f32m8(y, x, vl), vfrsub_vf_f32m8(x, y, vl), vfsub_vf_f32m8(y, x, vl))
MAKE_FUNCTION(binary_op_rdiv, y / x, vfdiv_vv_f32m8(y, x, vl), vfrdiv_vf_f32m8(x, y, vl), vfdiv_vf_f32m8(y, x, vl))
MAKE_FUNCTION(binary_op_rpow, (float)pow(y, x), pow_ps(y, x, vl), pow_ps(vfmv_v_f_f32m8(y, vl), x, vl), pow_ps(y, vfmv_v_f_f32m8(x, vl), vl))
MAKE_FUNCTION(binary_op_atan2, (float)atan2(x, y), atan2_ps(x, y, vl), atan2_ps(x, vfmv_v_f_f32m8(y, vl), vl), atan2_ps(vfmv_v_f_f32m8(x, vl), y, vl))
MAKE_FUNCTION(binary_op_ratan2, (float)atan2(y, x), atan2_ps(y, x, vl), atan2_ps(vfmv_v_f_f32m8(y, vl), x, vl), atan2_ps(y, vfmv_v_f_f32m8(x, vl), vl))
// *INDENT-ON*
// clang-format on
......@@ -600,6 +602,8 @@ static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Op
if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -616,9 +620,11 @@ static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_typ
if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast<binary_op_max>(a, b, c, opt);
if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast<binary_op_min>(a, b, c, opt);
if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast<binary_op_pow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast<binary_op_sub>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast<binary_op_div>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast<binary_op_pow>(b, a, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_no_broadcast<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_no_broadcast<binary_op_atan2>(b, a, c, opt);
// should never reach here
return 0;
......@@ -653,6 +659,8 @@ static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, int op_
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner<binary_op_rsub>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner<binary_op_rdiv>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner<binary_op_rpow>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_inner<binary_op_atan2>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_inner<binary_op_ratan2>(a, b2, c, opt);
// should never reach here
return 0;
......@@ -672,6 +680,8 @@ static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, int op_
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_outer<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_outer<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -691,6 +701,8 @@ static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, int op_typ
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_20<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_20<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -701,9 +713,11 @@ static int get_reverse_op_type(int op_type)
if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB;
if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV;
if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW;
if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2;
if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB;
if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV;
if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW;
if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2;
return op_type;
}
......@@ -793,6 +807,8 @@ int BinaryOp_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) con
if (op_type == Operation_RSUB) return binary_op_scalar_inplace<binary_op_rsub>(bottom_top_blob, b, opt);
if (op_type == Operation_RDIV) return binary_op_scalar_inplace<binary_op_rdiv>(bottom_top_blob, b, opt);
if (op_type == Operation_RPOW) return binary_op_scalar_inplace<binary_op_rpow>(bottom_top_blob, b, opt);
if (op_type == Operation_ATAN2) return binary_op_scalar_inplace<binary_op_atan2>(bottom_top_blob, b, opt);
if (op_type == Operation_RATAN2) return binary_op_scalar_inplace<binary_op_ratan2>(bottom_top_blob, b, opt);
return 0;
}
......@@ -1244,6 +1260,8 @@ MAKE_FUNCTION(binary_op_pow_fp16s, (__fp16)pow((float)x, (float)y), pow_ps(x, y,
MAKE_FUNCTION(binary_op_rsub_fp16s, y - x, vfsub_vv_f16m8(y, x, vl), vfrsub_vf_f16m8(x, y, vl), vfsub_vf_f16m8(y, x, vl))
MAKE_FUNCTION(binary_op_rdiv_fp16s, y / x, vfdiv_vv_f16m8(y, x, vl), vfrdiv_vf_f16m8(x, y, vl), vfdiv_vf_f16m8(y, x, vl))
MAKE_FUNCTION(binary_op_rpow_fp16s, (__fp16)pow((float)y, (float)x), pow_ps(y, x, vl), pow_ps(vfmv_v_f_f16m8(y, vl), x, vl), pow_ps(y, vfmv_v_f_f16m8(x, vl), vl))
MAKE_FUNCTION(binary_op_atan2_fp16s, (__fp16)atan2((float)x, (float)y), atan2_ps(x, y, vl), atan2_ps(x, vfmv_v_f_f16m8(y, vl), vl), atan2_ps(vfmv_v_f_f16m8(x, vl), y, vl))
MAKE_FUNCTION(binary_op_ratan2_fp16s, (__fp16)atan2((float)y, (float)x), atan2_ps(y, x, vl), atan2_ps(vfmv_v_f_f16m8(y, vl), x, vl), atan2_ps(y, vfmv_v_f_f16m8(x, vl), vl))
// *INDENT-ON*
// clang-format on
......@@ -1265,6 +1283,8 @@ static int binary_op_scalar_fp16s(const Mat& a, __fp16 b, Mat& c, int op_type, c
if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar_fp16s<binary_op_rsub_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar_fp16s<binary_op_rdiv_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar_fp16s<binary_op_rpow_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar_fp16s<binary_op_atan2_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar_fp16s<binary_op_ratan2_fp16s>(a, b, c, opt);
// should never reach here
return 0;
......@@ -1281,9 +1301,11 @@ static int binary_op_no_broadcast_fp16s(const Mat& a, const Mat& b, Mat& c, int
if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast_fp16s<binary_op_max_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast_fp16s<binary_op_min_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast_fp16s<binary_op_pow_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast_fp16s<binary_op_rsub_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast_fp16s<binary_op_rdiv_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast_fp16s<binary_op_rpow_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast_fp16s<binary_op_sub_fp16s>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast_fp16s<binary_op_div_fp16s>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast_fp16s<binary_op_pow_fp16s>(b, a, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_no_broadcast_fp16s<binary_op_atan2_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_no_broadcast_fp16s<binary_op_atan2_fp16s>(b, a, c, opt);
// should never reach here
return 0;
......@@ -1318,6 +1340,8 @@ static int binary_op_broadcast_inner_fp16s(const Mat& a, const Mat& b, Mat& c, i
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner_fp16s<binary_op_rsub_fp16s>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner_fp16s<binary_op_rdiv_fp16s>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner_fp16s<binary_op_rpow_fp16s>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_inner_fp16s<binary_op_atan2_fp16s>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_inner_fp16s<binary_op_ratan2_fp16s>(a, b2, c, opt);
// should never reach here
return 0;
......@@ -1337,6 +1361,8 @@ static int binary_op_broadcast_outer_fp16s(const Mat& a, const Mat& b, Mat& c, i
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer_fp16s<binary_op_rsub_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer_fp16s<binary_op_rdiv_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer_fp16s<binary_op_rpow_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_outer_fp16s<binary_op_atan2_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_outer_fp16s<binary_op_ratan2_fp16s>(a, b, c, opt);
// should never reach here
return 0;
......@@ -1356,6 +1382,8 @@ static int binary_op_broadcast_20_fp16s(const Mat& a, const Mat& b, Mat& c, int
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20_fp16s<binary_op_rsub_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20_fp16s<binary_op_rdiv_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20_fp16s<binary_op_rpow_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_20_fp16s<binary_op_atan2_fp16s>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_20_fp16s<binary_op_ratan2_fp16s>(a, b, c, opt);
// should never reach here
return 0;
......@@ -1429,6 +1457,8 @@ int BinaryOp_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& op
if (op_type == Operation_RSUB) return binary_op_scalar_inplace_fp16s<binary_op_rsub_fp16s>(bottom_top_blob, (__fp16)b, opt);
if (op_type == Operation_RDIV) return binary_op_scalar_inplace_fp16s<binary_op_rdiv_fp16s>(bottom_top_blob, (__fp16)b, opt);
if (op_type == Operation_RPOW) return binary_op_scalar_inplace_fp16s<binary_op_rpow_fp16s>(bottom_top_blob, (__fp16)b, opt);
if (op_type == Operation_ATAN2) return binary_op_scalar_inplace_fp16s<binary_op_atan2_fp16s>(bottom_top_blob, (__fp16)b, opt);
if (op_type == Operation_RATAN2) return binary_op_scalar_inplace_fp16s<binary_op_ratan2_fp16s>(bottom_top_blob, (__fp16)b, opt);
return 0;
}
......
......@@ -534,4 +534,24 @@ _RVV_FLOAT32_ERFC_OP(2, 16)
_RVV_FLOAT32_ERFC_OP(4, 8)
_RVV_FLOAT32_ERFC_OP(8, 4)
//TODO rvv optimize
#define _RVV_FLOAT32_ATAN2_OP(LMUL, MLEN) \
static inline vfloat32m##LMUL##_t atan2_ps(vfloat32m##LMUL##_t a, vfloat32m##LMUL##_t b, size_t vl) \
{ \
std::vector<float> tmpx(vl); \
std::vector<float> tmpy(vl); \
vse32_v_f32m##LMUL(tmpx.data(), a, vl); \
vse32_v_f32m##LMUL(tmpy.data(), b, vl); \
for (int i = 0; i < vl; i++) \
{ \
tmpx[i] = atan2(tmpx[i], tmpy[i]); \
} \
return vle32_v_f32m##LMUL(tmpx.data(), vl); \
}
_RVV_FLOAT32_ATAN2_OP(1, 32)
_RVV_FLOAT32_ATAN2_OP(2, 16)
_RVV_FLOAT32_ATAN2_OP(4, 8)
_RVV_FLOAT32_ATAN2_OP(8, 4)
#endif // RVV_MATHFUN_H
......@@ -370,4 +370,24 @@ _RVV_FLOAT16_SIGMOID_OP(2, 8)
_RVV_FLOAT16_SIGMOID_OP(4, 4)
_RVV_FLOAT16_SIGMOID_OP(8, 2)
//TODO rvv optimize
#define _RVV_FLOAT16_ATAN2_OP(LMUL, MLEN) \
static inline vfloat16m##LMUL##_t atan2_ps(vfloat16m##LMUL##_t a, vfloat16m##LMUL##_t b, size_t vl) \
{ \
std::vector<__fp16> tmpx(vl); \
std::vector<__fp16> tmpy(vl); \
vse16_v_f16m##LMUL(tmpx.data(), a, vl); \
vse16_v_f16m##LMUL(tmpy.data(), b, vl); \
for (int i = 0; i < vl; i++) \
{ \
tmpx[i] = (__fp16)atan2((float)tmpx[i], (float)tmpy[i]); \
} \
return vle16_v_f16m##LMUL(tmpx.data(), vl); \
}
_RVV_FLOAT16_ATAN2_OP(1, 32)
_RVV_FLOAT16_ATAN2_OP(2, 16)
_RVV_FLOAT16_ATAN2_OP(4, 8)
_RVV_FLOAT16_ATAN2_OP(8, 4)
#endif // RVV_MATHFUN_FP16S_H
......@@ -48,9 +48,11 @@ static int get_reverse_op_type(int op_type)
if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB;
if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV;
if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW;
if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2;
if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB;
if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV;
if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW;
if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2;
return op_type;
}
......
......@@ -130,6 +130,13 @@ void main()
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;
if (op_type == 9) res = pow(v2, v1);
#if NCNN_moltenvk
if (op_type == 10) res = afp(atan(float(v1), float(v2)));
if (op_type == 11) res = afp(atan(float(v2), float(v1)));
#else
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
#if NCNN_image_shader
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
......
......@@ -184,6 +184,13 @@ void main()
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;
if (op_type == 9) res = pow(v2, v1);
#if NCNN_moltenvk
if (op_type == 10) res = afp(atan(float(v1), float(v2)));
if (op_type == 11) res = afp(atan(float(v2), float(v1)));
#else
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
#if NCNN_image_shader
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
......
......@@ -184,6 +184,13 @@ void main()
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;
if (op_type == 9) res = pow(v2, v1);
#if NCNN_moltenvk
if (op_type == 10) res = afpvec4(atan(vec4(v1), vec4(v2)));
if (op_type == 11) res = afpvec4(atan(vec4(v2), vec4(v1)));
#else
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
......
......@@ -225,6 +225,26 @@ void main()
res[0] = pow(v2[0], v1[0]);
res[1] = pow(v2[1], v1[1]);
}
if (op_type == 10)
{
#if NCNN_moltenvk
res[0] = afpvec4(atan(vec4(v1[0]), vec4(v2[0])));
res[1] = afpvec4(atan(vec4(v1[1]), vec4(v2[1])));
#else
res[0] = atan(v1[0], v2[0]);
res[1] = atan(v1[1], v2[1]);
#endif
}
if (op_type == 11)
{
#if NCNN_moltenvk
res[0] = afpvec4(atan(vec4(v2[0]), vec4(v1[0])));
res[1] = afpvec4(atan(vec4(v2[1]), vec4(v1[1])));
#else
res[0] = atan(v2[0], v1[0]);
res[1] = atan(v2[1], v1[1]);
#endif
}
#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
......
......@@ -119,6 +119,13 @@ void main()
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;
if (op_type == 9) res = pow(v2, v1);
#if NCNN_moltenvk
if (op_type == 10) res = afp(atan(float(v1), float(v2)));
if (op_type == 11) res = afp(atan(float(v2), float(v1)));
#else
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
#if NCNN_image_shader
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);
......
......@@ -119,6 +119,13 @@ void main()
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;
if (op_type == 9) res = pow(v2, v1);
#if NCNN_moltenvk
if (op_type == 10) res = afpvec4(atan(vec4(v1), vec4(v2)));
if (op_type == 11) res = afpvec4(atan(vec4(v2), vec4(v1)));
#else
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
......
......@@ -163,6 +163,26 @@ void main()
res[0] = pow(v2[0], v1[0]);
res[1] = pow(v2[1], v1[1]);
}
if (op_type == 10)
{
#if NCNN_moltenvk
res[0] = afpvec4(atan(vec4(v1[0]), vec4(v2[0])));
res[1] = afpvec4(atan(vec4(v1[1]), vec4(v2[1])));
#else
res[0] = atan(v1[0], v2[0]);
res[1] = atan(v1[1], v2[1]);
#endif
}
if (op_type == 11)
{
#if NCNN_moltenvk
res[0] = afpvec4(atan(vec4(v2[0]), vec4(v1[0])));
res[1] = afpvec4(atan(vec4(v2[1]), vec4(v1[1])));
#else
res[0] = atan(v2[0], v1[0]);
res[1] = atan(v2[1], v1[1]);
#endif
}
#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
......
......@@ -121,6 +121,13 @@ void main()
if (op_type == 7) res = v2 - v1;
if (op_type == 8) res = v2 / v1;
if (op_type == 9) res = pow(v2, v1);
#if NCNN_moltenvk
if (op_type == 10) res = afpvec4(atan(vec4(v1), vec4(v2)));
if (op_type == 11) res = afpvec4(atan(vec4(v2), vec4(v1)));
#else
if (op_type == 10) res = atan(v1, v2);
if (op_type == 11) res = atan(v2, v1);
#endif
#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);
......
......@@ -163,6 +163,26 @@ void main()
res[0] = pow(v2[0], v1[0]);
res[1] = pow(v2[1], v1[1]);
}
if (op_type == 10)
{
#if NCNN_moltenvk
res[0] = afpvec4(atan(vec4(v1[0]), vec4(v2[0])));
res[1] = afpvec4(atan(vec4(v1[1]), vec4(v2[1])));
#else
res[0] = atan(v1[0], v2[0]);
res[1] = atan(v1[1], v2[1]);
#endif
}
if (op_type == 11)
{
#if NCNN_moltenvk
res[0] = afpvec4(atan(vec4(v2[0]), vec4(v1[0])));
res[1] = afpvec4(atan(vec4(v2[1]), vec4(v1[1])));
#else
res[0] = atan(v2[0], v1[0]);
res[1] = atan(v2[1], v1[1]);
#endif
}
#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);
......
......@@ -505,4 +505,16 @@ static NCNN_FORCEINLINE __m512 pow512_ps(__m512 a, __m512 b)
return exp512_ps(_mm512_mul_ps(b, log512_ps(a)));
}
static NCNN_FORCEINLINE __m512 atan2512_ps(__m512 a, __m512 b)
{
//TODO avx512 optimize
float tmpx[16];
float tmpy[16];
_mm512_storeu_ps(tmpx, a);
_mm512_storeu_ps(tmpy, b);
for (int i = 0; i < 16; i++)
tmpx[i] = atan2(tmpx[i], tmpy[i]);
return _mm512_loadu_ps(tmpx);
}
#endif // AVX512_MATHFUN_H
......@@ -751,4 +751,16 @@ static NCNN_FORCEINLINE __m256 pow256_ps(__m256 a, __m256 b)
return exp256_ps(_mm256_mul_ps(b, log256_ps(a)));
}
static NCNN_FORCEINLINE __m256 atan2256_ps(__m256 a, __m256 b)
{
//TODO avx optimize
float tmpx[8];
float tmpy[8];
_mm256_storeu_ps(tmpx, a);
_mm256_storeu_ps(tmpy, b);
for (int i = 0; i < 8; i++)
tmpx[i] = atan2(tmpx[i], tmpy[i]);
return _mm256_loadu_ps(tmpx);
}
#endif // AVX_MATHFUN_H
......@@ -1048,6 +1048,58 @@ struct binary_op_rpow
#endif // __SSE2__
};
struct binary_op_atan2
{
float func(const float& x, const float& y) const
{
return (float)atan2(x, y);
}
#if __SSE2__
__m128 func_pack4(const __m128& x, const __m128& y) const
{
return atan2_ps(x, y);
}
#if __AVX__
__m256 func_pack8(const __m256& x, const __m256& y) const
{
return atan2256_ps(x, y);
}
#if __AVX512F__
__m512 func_pack16(const __m512& x, const __m512& y) const
{
return atan2512_ps(x, y);
}
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__
};
struct binary_op_ratan2
{
float func(const float& x, const float& y) const
{
return (float)atan2(y, x);
}
#if __SSE2__
__m128 func_pack4(const __m128& x, const __m128& y) const
{
return atan2_ps(y, x);
}
#if __AVX__
__m256 func_pack8(const __m256& x, const __m256& y) const
{
return atan2256_ps(y, x);
}
#if __AVX512F__
__m512 func_pack16(const __m512& x, const __m512& y) const
{
return atan2512_ps(y, x);
}
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__
};
} // namespace BinaryOp_x86_functor
static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Option& opt)
......@@ -1064,6 +1116,8 @@ static int binary_op_scalar(const Mat& a, float b, Mat& c, int op_type, const Op
if (op_type == BinaryOp::Operation_RSUB) return binary_op_scalar<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_scalar<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_scalar<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_scalar<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_scalar<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -1080,9 +1134,11 @@ static int binary_op_no_broadcast(const Mat& a, const Mat& b, Mat& c, int op_typ
if (op_type == BinaryOp::Operation_MAX) return binary_op_no_broadcast<binary_op_max>(a, b, c, opt);
if (op_type == BinaryOp::Operation_MIN) return binary_op_no_broadcast<binary_op_min>(a, b, c, opt);
if (op_type == BinaryOp::Operation_POW) return binary_op_no_broadcast<binary_op_pow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RSUB) return binary_op_no_broadcast<binary_op_sub>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_no_broadcast<binary_op_div>(b, a, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_no_broadcast<binary_op_pow>(b, a, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_no_broadcast<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_no_broadcast<binary_op_atan2>(b, a, c, opt);
// should never reach here
return 0;
......@@ -1117,6 +1173,8 @@ static int binary_op_broadcast_inner(const Mat& a, const Mat& b, Mat& c, int op_
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_inner<binary_op_rsub>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_inner<binary_op_rdiv>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_inner<binary_op_rpow>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_inner<binary_op_atan2>(a, b2, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_inner<binary_op_ratan2>(a, b2, c, opt);
// should never reach here
return 0;
......@@ -1136,6 +1194,8 @@ static int binary_op_broadcast_outer(const Mat& a, const Mat& b, Mat& c, int op_
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_outer<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_outer<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_outer<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_outer<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_outer<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -1155,6 +1215,8 @@ static int binary_op_broadcast_20(const Mat& a, const Mat& b, Mat& c, int op_typ
if (op_type == BinaryOp::Operation_RSUB) return binary_op_broadcast_20<binary_op_rsub>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RDIV) return binary_op_broadcast_20<binary_op_rdiv>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RPOW) return binary_op_broadcast_20<binary_op_rpow>(a, b, c, opt);
if (op_type == BinaryOp::Operation_ATAN2) return binary_op_broadcast_20<binary_op_atan2>(a, b, c, opt);
if (op_type == BinaryOp::Operation_RATAN2) return binary_op_broadcast_20<binary_op_ratan2>(a, b, c, opt);
// should never reach here
return 0;
......@@ -1165,9 +1227,11 @@ static int get_reverse_op_type(int op_type)
if (op_type == BinaryOp::Operation_SUB) return BinaryOp::Operation_RSUB;
if (op_type == BinaryOp::Operation_DIV) return BinaryOp::Operation_RDIV;
if (op_type == BinaryOp::Operation_POW) return BinaryOp::Operation_RPOW;
if (op_type == BinaryOp::Operation_ATAN2) return BinaryOp::Operation_RATAN2;
if (op_type == BinaryOp::Operation_RSUB) return BinaryOp::Operation_SUB;
if (op_type == BinaryOp::Operation_RDIV) return BinaryOp::Operation_DIV;
if (op_type == BinaryOp::Operation_RPOW) return BinaryOp::Operation_POW;
if (op_type == BinaryOp::Operation_RATAN2) return BinaryOp::Operation_ATAN2;
return op_type;
}
......@@ -1239,6 +1303,8 @@ int BinaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_RSUB) return binary_op_scalar_inplace<binary_op_rsub>(bottom_top_blob, b, opt);
if (op_type == Operation_RDIV) return binary_op_scalar_inplace<binary_op_rdiv>(bottom_top_blob, b, opt);
if (op_type == Operation_RPOW) return binary_op_scalar_inplace<binary_op_rpow>(bottom_top_blob, b, opt);
if (op_type == Operation_ATAN2) return binary_op_scalar_inplace<binary_op_atan2>(bottom_top_blob, b, opt);
if (op_type == Operation_RATAN2) return binary_op_scalar_inplace<binary_op_ratan2>(bottom_top_blob, b, opt);
return 0;
}
......
......@@ -738,4 +738,18 @@ static NCNN_FORCEINLINE __m128 pow_ps(__m128 a, __m128 b)
return exp_ps(_mm_mul_ps(b, log_ps(a)));
}
static NCNN_FORCEINLINE __m128 atan2_ps(__m128 a, __m128 b)
{
//TODO sse optimize
float tmpx[4];
float tmpy[4];
_mm_storeu_ps(tmpx, a);
_mm_storeu_ps(tmpy, b);
tmpx[0] = atan2(tmpx[0], tmpy[0]);
tmpx[1] = atan2(tmpx[1], tmpy[1]);
tmpx[2] = atan2(tmpx[2], tmpy[2]);
tmpx[3] = atan2(tmpx[3], tmpy[3]);
return _mm_loadu_ps(tmpx);
}
#endif // SSE_MATHFUN_H
......@@ -15,7 +15,7 @@
#include "layer/binaryop.h"
#include "testutil.h"
#define OP_TYPE_MAX 10
#define OP_TYPE_MAX 12
static int op_type = 0;
......
......@@ -15,7 +15,7 @@
#include "layer/binaryop.h"
#include "testutil.h"
#define OP_TYPE_MAX 10
#define OP_TYPE_MAX 12
static int op_type = 0;
......
......@@ -15,7 +15,7 @@
#include "layer/binaryop.h"
#include "testutil.h"
#define OP_TYPE_MAX 10
#define OP_TYPE_MAX 12
static int op_type = 0;
......@@ -368,7 +368,7 @@ int main()
{
SRAND(7767517);
for (op_type = 6; op_type < OP_TYPE_MAX; op_type++)
for (op_type = 6; op_type < 9; op_type++)
{
int ret = 0
|| test_binaryop_1()
......
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include "layer/binaryop.h"
#include "testutil.h"
#define OP_TYPE_MAX 12
static int op_type = 0;
static int test_binaryop(const ncnn::Mat& _a, const ncnn::Mat& _b)
{
ncnn::Mat a = _a;
ncnn::Mat b = _b;
if (op_type == 6 || op_type == 9)
{
// value must be positive for pow/rpow
a = a.clone();
b = b.clone();
Randomize(a, 0.001f, 2.f);
Randomize(b, 0.001f, 2.f);
}
if (op_type == 3 || op_type == 8)
{
// value must be positive for div/rdiv
a = a.clone();
b = b.clone();
Randomize(a, 0.1f, 10.f);
Randomize(b, 0.1f, 10.f);
}
ncnn::ParamDict pd;
pd.set(0, op_type);
pd.set(1, 0); // with_scalar
pd.set(2, 0.f); // b
std::vector<ncnn::Mat> weights(0);
std::vector<ncnn::Mat> ab(2);
ab[0] = a;
ab[1] = b;
int ret = test_layer<ncnn::BinaryOp>("BinaryOp", pd, weights, ab);
if (ret != 0)
{
fprintf(stderr, "test_binaryop failed a.dims=%d a=(%d %d %d %d) b.dims=%d b=(%d %d %d %d) op_type=%d\n", a.dims, a.w, a.h, a.d, a.c, b.dims, b.w, b.h, b.d, b.c, op_type);
}
return ret;
}
static int test_binaryop(const ncnn::Mat& _a, float b)
{
ncnn::Mat a = _a;
if (op_type == 6 || op_type == 9)
{
// value must be positive for pow/rpow
Randomize(a, 0.001f, 2.f);
b = RandomFloat(0.001f, 2.f);
}
if (op_type == 3 || op_type == 8)
{
// value must be positive for div/rdiv
a = a.clone();
Randomize(a, 0.1f, 10.f);
}
ncnn::ParamDict pd;
pd.set(0, op_type);
pd.set(1, 1); // with_scalar
pd.set(2, b); // b
std::vector<ncnn::Mat> weights(0);
int ret = test_layer<ncnn::BinaryOp>("BinaryOp", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_binaryop failed a.dims=%d a=(%d %d %d %d) b=%f op_type=%d\n", a.dims, a.w, a.h, a.d, a.c, b, op_type);
}
return ret;
}
// https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting
static int test_binaryop_1()
{
ncnn::Mat a[] = {
RandomMat(31),
RandomMat(28),
RandomMat(24),
RandomMat(32),
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32),
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32),
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};
ncnn::Mat b[] = {
RandomMat(1),
RandomMat(1, 1),
RandomMat(1, 1, 1),
RandomMat(1, 1, 1, 1)
};
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
for (int j = 0; j < sizeof(b) / sizeof(b[0]); j++)
{
int ret = test_binaryop(a[i], b[j]) || test_binaryop(b[j], a[i]);
if (ret != 0)
return ret;
}
int ret = test_binaryop(a[i], 0.2f);
if (ret != 0)
return ret;
}
return 0;
}
static int test_binaryop_2()
{
ncnn::Mat a[] = {
RandomMat(31),
RandomMat(28),
RandomMat(24),
RandomMat(32),
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32),
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32),
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b;
b.create_like(a[i]);
Randomize(b);
int ret = test_binaryop(a[i], b) || test_binaryop(b, a[i]);
if (ret != 0)
return ret;
}
return 0;
}
static int test_binaryop_3()
{
ncnn::Mat a[] = {
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32)
};
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].h);
ncnn::Mat b1(1, a[i].h);
Randomize(b0);
Randomize(b1);
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i]);
if (ret != 0)
return ret;
}
return 0;
}
static int test_binaryop_4()
{
ncnn::Mat a[] = {
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32)
};
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].c);
ncnn::Mat b1(1, 1, a[i].c);
ncnn::Mat b2(a[i].h, a[i].c);
ncnn::Mat b3(1, a[i].h, a[i].c);
Randomize(b0);
Randomize(b1);
Randomize(b2);
Randomize(b3);
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i])
|| test_binaryop(a[i], b2) || test_binaryop(b2, a[i])
|| test_binaryop(a[i], b3) || test_binaryop(b3, a[i]);
if (ret != 0)
return ret;
}
return 0;
}
static int test_binaryop_5()
{
ncnn::Mat a[] = {
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].c);
ncnn::Mat b1(1, 1, 1, a[i].c);
ncnn::Mat b2(a[i].d, a[i].c);
ncnn::Mat b3(1, 1, a[i].d, a[i].c);
ncnn::Mat b4(a[i].h, a[i].d, a[i].c);
ncnn::Mat b5(1, a[i].h, a[i].d, a[i].c);
Randomize(b0);
Randomize(b1);
Randomize(b2);
Randomize(b3);
Randomize(b4);
Randomize(b5);
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i])
|| test_binaryop(a[i], b2) || test_binaryop(b2, a[i])
|| test_binaryop(a[i], b3) || test_binaryop(b3, a[i])
|| test_binaryop(a[i], b4) || test_binaryop(b4, a[i])
|| test_binaryop(a[i], b5) || test_binaryop(b5, a[i]);
if (ret != 0)
return ret;
}
return 0;
}
static int test_binaryop_6()
{
ncnn::Mat a[] = {
RandomMat(13, 31),
RandomMat(14, 28),
RandomMat(15, 24),
RandomMat(16, 32)
};
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1);
Randomize(b0);
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]);
if (ret != 0)
return ret;
}
return 0;
}
static int test_binaryop_7()
{
ncnn::Mat a[] = {
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32)
};
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1, 1);
ncnn::Mat b1(a[i].w, a[i].h, 1);
Randomize(b0);
Randomize(b1);
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i]);
if (ret != 0)
return ret;
}
return 0;
}
static int test_binaryop_8()
{
ncnn::Mat a[] = {
RandomMat(2, 7, 3, 31),
RandomMat(3, 6, 4, 28),
RandomMat(4, 5, 5, 24),
RandomMat(5, 4, 6, 32)
};
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1, 1, 1);
ncnn::Mat b1(a[i].w, a[i].h, 1, 1);
ncnn::Mat b2(a[i].w, a[i].h, a[i].d, 1);
Randomize(b0);
Randomize(b1);
Randomize(b2);
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i])
|| test_binaryop(a[i], b1) || test_binaryop(b1, a[i])
|| test_binaryop(a[i], b2) || test_binaryop(b2, a[i]);
if (ret != 0)
return ret;
}
return 0;
}
static int test_binaryop_9()
{
ncnn::Mat a[] = {
RandomMat(7, 3, 31),
RandomMat(6, 4, 28),
RandomMat(5, 5, 24),
RandomMat(4, 6, 32)
};
for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++)
{
ncnn::Mat b0(a[i].w, 1, a[i].c);
Randomize(b0);
int ret = test_binaryop(a[i], b0) || test_binaryop(b0, a[i]);
if (ret != 0)
return ret;
}
return 0;
}
int main()
{
SRAND(7767517);
for (op_type = 9; op_type < OP_TYPE_MAX; op_type++)
{
int ret = 0
|| test_binaryop_1()
|| test_binaryop_2()
|| test_binaryop_3()
|| test_binaryop_4()
|| test_binaryop_5()
|| test_binaryop_6()
|| test_binaryop_7()
|| test_binaryop_8()
|| test_binaryop_9();
if (ret != 0)
return ret;
}
return 0;
}
......@@ -174,7 +174,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
op_unary->inputs.push_back(op_unary_in);
op_unary->outputs.push_back(op_unary_out);
}
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || /*t == "floor_divide" || */ t == "pow")
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || /*t == "floor_divide" || */ t == "pow" || t == "atan2")
{
std::string a = exprstack.top();
exprstack.pop();
......@@ -191,11 +191,14 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
if (t == "mul") op_binary->params["0"] = 2;
if (t == "div") op_binary->params["0"] = 3;
if (t == "pow") op_binary->params["0"] = 6;
if (t == "atan2") op_binary->params["0"] = 10;
if (token_is_literal(a))
{
if (t == "sub") op_binary->params["0"] = 7;
if (t == "div") op_binary->params["0"] = 8;
if (t == "pow") op_binary->params["0"] = 9;
if (t == "atan2") op_binary->params["0"] = 11;
Operand* op_binary_inb = token_is_argument(b) ? op->inputs[std::stoi(b.substr(1))] : graph.get_operand(op->name + "_" + b);
op_binary_inb->consumers.push_back(op_binary);
......
......@@ -161,6 +161,7 @@ pnnx_ncnn_add_test(torch_abs)
pnnx_ncnn_add_test(torch_acos)
pnnx_ncnn_add_test(torch_asin)
pnnx_ncnn_add_test(torch_atan)
pnnx_ncnn_add_test(torch_atan2)
pnnx_ncnn_add_test(torch_ceil)
pnnx_ncnn_add_test(torch_clamp)
pnnx_ncnn_add_test(torch_cos)
......
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, x, y, z):
out0 = torch.atan2(x, y)
out1 = torch.atan2(y, y)
out2 = torch.atan2(z, torch.ones_like(z) + 0.5)
return out0, out1, out2
def test():
net = Model()
net.eval()
torch.manual_seed(0)
x = torch.rand(3, 16)
y = torch.rand(3, 16)
z = torch.rand(5, 9, 3)
a = net(x, y, z)
# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_torch_atan2.pt")
# torchscript to pnnx
import os
os.system("../../src/pnnx test_torch_atan2.pt inputshape=[3,16],[3,16],[5,9,3]")
# ncnn inference
import test_torch_atan2_ncnn
b = test_torch_atan2_ncnn.test_inference()
for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True
if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册