提交 2b0811c3 编写于 作者: T tensor-tang

refine vadd jitkernel choice

test=develop
上级 a18c0d42
...@@ -93,6 +93,7 @@ std::vector<int> TestSizes() { ...@@ -93,6 +93,7 @@ std::vector<int> TestSizes() {
template <typename KernelTuples, typename... Args> template <typename KernelTuples, typename... Args>
struct BenchFunc { struct BenchFunc {
// return this function avg time // return this function avg time
// TODO(TJ): clear cache every time
double operator()(const typename KernelTuples::func_type tgt, Args... args) { double operator()(const typename KernelTuples::func_type tgt, Args... args) {
for (int i = 0; i < FLAGS_burning; ++i) { for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...); tgt(args...);
...@@ -172,6 +173,9 @@ void BenchXYZNKernel() { ...@@ -172,6 +173,9 @@ void BenchXYZNKernel() {
RandomVec<T>(d, y_data); RandomVec<T>(d, y_data);
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(),
y.data<T>(), z_data, d); y.data<T>(), z_data, d);
// test inplace
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), z_data,
z_data, d);
} }
} }
......
...@@ -155,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> { ...@@ -155,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
class name##Creator : public JitCodeCreator<int> { \ class name##Creator : public JitCodeCreator<int> { \
public: \ public: \
bool UseMe(const int& attr) const override { \ bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \ return platform::MayIUse(platform::avx) && attr <= 1024; \
} \ } \
size_t CodeSize(const int& d) const override { \ size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \ return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
......
...@@ -61,6 +61,7 @@ class VXXJitCode : public JitCode { ...@@ -61,6 +61,7 @@ class VXXJitCode : public JitCode {
base += "_Vec"; base += "_Vec";
} }
base += (with_relu_ ? "_Relu" : ""); base += (with_relu_ ? "_Relu" : "");
base += "_D" + std::to_string(num_);
return base.c_str(); return base.c_str();
} }
void genCode() override; void genCode() override;
......
...@@ -139,7 +139,7 @@ bool VMulKernel<float>::UseMe(const int& d) const { ...@@ -139,7 +139,7 @@ bool VMulKernel<float>::UseMe(const int& d) const {
template <> template <>
bool VAddKernel<float>::UseMe(const int& d) const { bool VAddKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512; return platform::MayIUse(platform::avx) && d > 512;
} }
template <> template <>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册