提交 d04515aa 编写于 作者: X Xinqi

assert that the oprant types of binary-expression-lazy-blob core op must be identical


Former-commit-id: 8518e1bdb3bf3bf90c5da94a708243cb0808c637
上级 b96550dd
......@@ -193,8 +193,8 @@ class VarLazyBlob final : public LazyBlobIf<VarLazyBlob<T>> {
template<template<typename> class CoreFunc, typename XT>
class UnaryExpresionLazyBlob final : public LazyBlobIf<UnaryExpresionLazyBlob<CoreFunc, XT>> {
public:
using XDT = typename XT::dtype;
typedef decltype(CoreFunc<XDT>::Invoke(*(const XDT*)nullptr)) dtype;
using T = typename XT::dtype;
typedef decltype(CoreFunc<T>::Invoke(*(const T*)nullptr)) dtype;
OF_DISALLOW_COPY_AND_MOVE(UnaryExpresionLazyBlob);
explicit UnaryExpresionLazyBlob(const XT& x)
......@@ -202,31 +202,32 @@ class UnaryExpresionLazyBlob final : public LazyBlobIf<UnaryExpresionLazyBlob<Co
GetDataType<dtype>::value),
x_(x) {}
inline dtype operator()(int64_t dim0) const { return CoreFunc<XDT>::Invoke(x_(dim0)); }
inline dtype operator()(int64_t dim0) const { return CoreFunc<T>::Invoke(x_(dim0)); }
inline dtype operator()(int64_t dim0, int64_t dim1) const {
return CoreFunc<XDT>::Invoke(x_(dim0, dim1));
return CoreFunc<T>::Invoke(x_(dim0, dim1));
}
inline dtype operator()(int64_t dim0, int64_t dim1, int64_t dim2) const {
return CoreFunc<XDT>::Invoke(x_(dim0, dim1, dim2));
return CoreFunc<T>::Invoke(x_(dim0, dim1, dim2));
}
inline dtype operator()(int64_t dim0, int64_t dim1, int64_t dim2, int64_t dim3) const {
return CoreFunc<XDT>::Invoke(x_(dim0, dim1, dim2, dim3));
return CoreFunc<T>::Invoke(x_(dim0, dim1, dim2, dim3));
}
inline dtype operator()(int64_t dim0, int64_t dim1, int64_t dim2, int64_t dim3,
int64_t dim4) const {
return CoreFunc<XDT>::Invoke(x_(dim0, dim1, dim2, dim3, dim4));
return CoreFunc<T>::Invoke(x_(dim0, dim1, dim2, dim3, dim4));
}
private:
const XT& x_;
};
template<template<typename, typename> class CoreFunc, typename XT, typename YT = XT>
template<template<typename> class CoreFunc, typename XT, typename YT = XT,
typename = typename std::enable_if<
std::is_same<typename XT::dtype, typename YT::dtype>::value>::type>
class BinaryExpresionLazyBlob final : public LazyBlobIf<BinaryExpresionLazyBlob<CoreFunc, XT, YT>> {
public:
using XDT = typename XT::dtype;
using YDT = typename YT::dtype;
typedef decltype(CoreFunc<XDT, YDT>::Invoke(*(const XDT*)nullptr, *(const YDT*)nullptr)) dtype;
using T = typename XT::dtype;
typedef decltype(CoreFunc<T>::Invoke(*(const T*)nullptr, *(const T*)nullptr)) dtype;
OF_DISALLOW_COPY_AND_MOVE(BinaryExpresionLazyBlob);
BinaryExpresionLazyBlob(const XT& x, const YT& y)
......@@ -237,22 +238,19 @@ class BinaryExpresionLazyBlob final : public LazyBlobIf<BinaryExpresionLazyBlob<
CHECK(x.shape() == y.shape());
}
inline dtype operator()(int64_t dim0) const {
return CoreFunc<XDT, YDT>::Invoke(x_(dim0), y_(dim0));
}
inline dtype operator()(int64_t dim0) const { return CoreFunc<T>::Invoke(x_(dim0), y_(dim0)); }
inline dtype operator()(int64_t dim0, int64_t dim1) const {
return CoreFunc<XDT, YDT>::Invoke(x_(dim0, dim1), y_(dim0, dim1));
return CoreFunc<T>::Invoke(x_(dim0, dim1), y_(dim0, dim1));
}
inline dtype operator()(int64_t dim0, int64_t dim1, int64_t dim2) const {
return CoreFunc<XDT, YDT>::Invoke(x_(dim0, dim1, dim2), y_(dim0, dim1, dim2));
return CoreFunc<T>::Invoke(x_(dim0, dim1, dim2), y_(dim0, dim1, dim2));
}
inline dtype operator()(int64_t dim0, int64_t dim1, int64_t dim2, int64_t dim3) const {
return CoreFunc<XDT, YDT>::Invoke(x_(dim0, dim1, dim2, dim3), y_(dim0, dim1, dim2, dim3));
return CoreFunc<T>::Invoke(x_(dim0, dim1, dim2, dim3), y_(dim0, dim1, dim2, dim3));
}
inline dtype operator()(int64_t dim0, int64_t dim1, int64_t dim2, int64_t dim3,
int64_t dim4) const {
return CoreFunc<XDT, YDT>::Invoke(x_(dim0, dim1, dim2, dim3, dim4),
y_(dim0, dim1, dim2, dim3, dim4));
return CoreFunc<T>::Invoke(x_(dim0, dim1, dim2, dim3, dim4), y_(dim0, dim1, dim2, dim3, dim4));
}
private:
......@@ -263,8 +261,7 @@ class BinaryExpresionLazyBlob final : public LazyBlobIf<BinaryExpresionLazyBlob<
template<typename XT>
class BroadcastLazyBlob final : public LazyBlobIf<BroadcastLazyBlob<XT>> {
public:
using XDT = typename XT::dtype;
typedef XDT dtype;
typedef typename XT::dtype dtype;
OF_DISALLOW_MOVE(BroadcastLazyBlob);
BroadcastLazyBlob(const BroadcastLazyBlob<XT>&) = default;
......@@ -355,10 +352,10 @@ class LazyBlobVarBuilder final {
OF_PP_MAKE_TUPLE_SEQ(LogicalAnd, &&, bool) \
OF_PP_MAKE_TUPLE_SEQ(LogicalOr, &&, bool)
#define DECLARE_LAZY_BLOB_BINARY_CORE(name, op, ret_type) \
template<typename T, typename YDT = T> \
struct LazyBlobCore##name final { \
static inline ret_type Invoke(const T x, const YDT y) { return x op y; } \
#define DECLARE_LAZY_BLOB_BINARY_CORE(name, op, ret_type) \
template<typename T> \
struct LazyBlobCore##name final { \
static inline ret_type Invoke(const T x, const T y) { return x op y; } \
};
OF_PP_FOR_EACH_TUPLE(DECLARE_LAZY_BLOB_BINARY_CORE, LAZY_BLOB_BINARY_CORE_OP_FUNC_SEQ);
......@@ -373,7 +370,7 @@ OF_PP_FOR_EACH_TUPLE(DECLARE_LAZY_BLOB_BINARY_CORE, LAZY_BLOB_BINARY_CORE_OP_FUN
};
OF_PP_FOR_EACH_TUPLE(DECLARE_LAZY_BLOB_UNARY_CORE, LAZY_BLOB_UNARY_CORE_OP_FUNC_SEQ);
template<template<typename, typename> class LazyBlobCoreFunc, typename XT, typename YT = XT>
template<template<typename> class LazyBlobCoreFunc, typename XT, typename YT = XT>
typename std::enable_if<std::is_base_of<LazyBlobNode, XT>::value
&& std::is_base_of<LazyBlobNode, YT>::value,
BinaryExpresionLazyBlob<LazyBlobCoreFunc, XT, YT>>::type&
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册