提交 2b0811c3 编写于 作者: T tensor-tang

refine vadd jitkernel choice

test=develop
上级 a18c0d42
......@@ -93,6 +93,7 @@ std::vector<int> TestSizes() {
template <typename KernelTuples, typename... Args>
struct BenchFunc {
// return this function avg time
// TODO(TJ): clear cache every time
double operator()(const typename KernelTuples::func_type tgt, Args... args) {
for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...);
......@@ -172,6 +173,9 @@ void BenchXYZNKernel() {
RandomVec<T>(d, y_data);
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(),
y.data<T>(), z_data, d);
// test inplace
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data<T>(), z_data,
z_data, d);
}
}
......
......@@ -155,7 +155,7 @@ class NCHW16CMulNCCreator : public JitCodeCreator<int> {
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
return platform::MayIUse(platform::avx) && attr <= 1024; \
} \
size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
......
......@@ -61,6 +61,7 @@ class VXXJitCode : public JitCode {
base += "_Vec";
}
base += (with_relu_ ? "_Relu" : "");
base += "_D" + std::to_string(num_);
return base.c_str();
}
void genCode() override;
......
......@@ -139,7 +139,7 @@ bool VMulKernel<float>::UseMe(const int& d) const {
template <>
bool VAddKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
return platform::MayIUse(platform::avx) && d > 512;
}
template <>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册