未验证 提交 21bb4a90 编写于 作者: L Li Xinqi 提交者: GitHub

Bugfix static check (#5935)

* refactor static_check

* fix bug in IsOutArg
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 2e2175b1
......@@ -18,8 +18,8 @@ limitations under the License.
#include <type_traits>
#include <unordered_map>
#include "oneflow/core/common/tuple_hash.h"
#include "oneflow/core/common/static_check.h"
#include "tuple_hash.h"
#include "static_check.h"
namespace oneflow {
......@@ -64,9 +64,8 @@ struct ThreadLocalCopiable<RetT, Arg0> {
}
private:
static void StaticCheckNotOutArg() {
auto* _ = &static_check::ForEachArgsType<static_check::CheckNotOutArg, Arg0>;
}
static_assert(!IsOutArg<Arg0>::value, "");
static_assert(!StaticAny<IsOutArg, Arg0>::value, "");
};
template<typename RetT, typename Arg0, typename Arg1>
......@@ -84,9 +83,7 @@ struct ThreadLocalCopiable<RetT, Arg0, Arg1> {
}
private:
static void StaticCheckNotOutArg() {
auto* _ = &static_check::ForEachArgsType<static_check::CheckNotOutArg, Arg0, Arg1>;
}
static_assert(!StaticAny<IsOutArg, Arg0, Arg1>::value, "");
};
template<typename RetT, typename Arg0, typename Arg1, typename Arg2>
......@@ -107,9 +104,7 @@ struct ThreadLocalCopiable<RetT, Arg0, Arg1, Arg2> {
}
private:
static void StaticCheckNotOutArg() {
auto* _ = &static_check::ForEachArgsType<static_check::CheckNotOutArg, Arg0, Arg1, Arg2>;
}
static_assert(!StaticAny<IsOutArg, Arg0, Arg1, Arg2>::value, "");
};
template<typename RetT, typename Arg0, typename Arg1, typename Arg2, typename Arg3,
......@@ -131,19 +126,14 @@ struct ThreadLocalCopiable<RetT, Arg0, Arg1, Arg2, Arg3, Args...> {
}
private:
static void StaticCheckNotOutArg() {
auto* _ = &static_check::ForEachArgsType<static_check::CheckNotOutArg, Arg0, Arg1, Arg2, Arg3,
Args...>;
}
static_assert(!StaticAny<IsOutArg, Arg0, Arg1, Arg2, Arg3, Args...>::value, "");
};
// for scalar type key.
template<typename... Args>
struct ThreadLocal : public ThreadLocalCopiable<Args...> {
template<typename RetT, typename... Args>
struct ThreadLocal : public ThreadLocalCopiable<RetT, Args...> {
private:
static void StaticCheckIsScalarType() {
auto* _0 = &static_check::ForEachArgsType<static_check::CheckIsScalarType, Args...>;
}
static_assert(StaticAll<IsDecayedScalarType, Args...>::value, "");
};
} // namespace oneflow
......
......@@ -16,36 +16,66 @@ limitations under the License.
#ifndef ONEFLOW_CORE_COMMON_STATIC_CHECK_H_
#define ONEFLOW_CORE_COMMON_STATIC_CHECK_H_
#include "oneflow/core/common/type_traits.h"
#include "type_traits.h"
namespace oneflow {
namespace static_check {
namespace private_details {
template<typename... Args>
void ForEachArgsType(Args... args);
template<template<typename> class Predicator>
struct StaticReduce {
template<typename... Args>
struct All;
template<typename Void>
struct All<Void> {
static_assert(std::is_same<Void, void>::value, "");
static constexpr bool value = true;
};
template<typename Void, typename T, typename... Args>
struct All<Void, T, Args...> {
static constexpr bool value = Predicator<T>::value && All<Void, Args...>::value;
};
template<template<typename> class Checker>
inline void ForEachArgsType() {}
template<typename... Args>
struct Any;
template<typename Void>
struct Any<Void> {
static_assert(std::is_same<Void, void>::value, "");
static constexpr bool value = false;
};
template<typename Void, typename T, typename... Args>
struct Any<Void, T, Args...> {
static constexpr bool value = Predicator<T>::value || Any<Void, Args...>::value;
};
};
} // namespace private_details
template<template<typename> class Checker, typename T, typename... Args>
void ForEachArgsType(T a, Args... args) {
Checker<T> check{};
ForEachArgsType<Checker>(args...);
template<template<typename> class Predicator, typename... Args>
struct StaticAll {
static constexpr bool value =
private_details::StaticReduce<Predicator>::template All<void, Args...>::value;
};
template<typename T>
struct CheckNotOutArg {
static_assert(!(std::is_pointer<T>::value && !std::is_const<T>::value), "");
static_assert(!(std::is_reference<T>::value && !std::is_const<T>::value), "");
template<template<typename> class Predicator, typename... Args>
struct StaticAny {
static constexpr bool value =
private_details::StaticReduce<Predicator>::template Any<void, Args...>::value;
};
template<typename T>
struct CheckIsScalarType {
static_assert(IsScalarType<typename std::decay<T>::type>::value, "");
struct IsOutArg {
static constexpr bool value =
(std::is_reference<T>::value
&& !std::is_const<typename std::remove_reference<T>::type>::value)
|| (std::is_pointer<T>::value
&& !std::is_const<typename std::remove_pointer<T>::type>::value);
};
} // namespace static_check
template<typename T>
struct IsDecayedScalarType {
static constexpr bool value = IsScalarType<typename std::decay<T>::type>::value;
};
} // namespace oneflow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册