// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "paddle/fluid/platform/hostdevice.h" namespace paddle { namespace framework { namespace detail { template struct UnrollFillConstant { template HOSTDEVICE inline static void Run(T *data, T val) { data[kStart] = val; UnrollFillConstant::Run(data, val); } }; template struct UnrollFillConstant { template HOSTDEVICE inline static void Run(T *data, T val) {} }; template struct UnrollAssign { template HOSTDEVICE inline static void Run(const Tin *d1, Tout *d2) { d2[kStart] = static_cast(d1[kStart]); UnrollAssign::Run(d1, d2); } }; template struct UnrollAssign { template HOSTDEVICE inline static void Run(const Tin *d1, Tout *d2) {} }; template struct UnrollVarArgsAssign { template HOSTDEVICE inline static void Run(T *d, T val, Args... args) { static_assert(sizeof...(args) + 1 == kEnd - kStart, "Wrong argument"); d[kStart] = val; UnrollVarArgsAssign::Run(d, args...); } }; template struct UnrollVarArgsAssign { HOSTDEVICE inline static void Run(T *d) {} }; template struct UnrollCompare { template HOSTDEVICE inline static bool Run(const T *d1, const T *d2) { return d1[kStart] == d2[kStart] && UnrollCompare::Run(d1, d2); } }; template struct UnrollCompare { template HOSTDEVICE inline constexpr static bool Run(const T *d1, const T *d2) { return true; } }; template struct UnrollAdd { template HOSTDEVICE inline static void Run(const T *d1, const T *d2, T *d3) { d3[kStart] = d1[kStart] + d2[kStart]; UnrollAdd::Run(d1, d2, d3); } }; template struct UnrollAdd { template HOSTDEVICE inline static void Run(const T *d1, const T *d2, T *d3) {} }; template struct UnrollMul { template HOSTDEVICE inline static void Run(const T *d1, const T *d2, T *d3) { d3[kStart] = d1[kStart] * d2[kStart]; UnrollMul::Run(d1, d2, d3); } }; template struct UnrollMul { template HOSTDEVICE inline static void Run(const T *d1, const T *d2, T *d3) {} }; template struct UnrollProduct { template HOSTDEVICE inline static T Run(const T *d) { return d[kStart] * UnrollProduct::Run(d); } template HOSTDEVICE inline static T Run(const T *d1, const T *d2) { return d1[kStart] * d2[kStart] + UnrollProduct::Run(d1, d2); } }; template struct UnrollProduct { template HOSTDEVICE inline constexpr static T Run(const T *d) { return 1; } template HOSTDEVICE inline constexpr static T Run(const T *d1, const T *d2) { return 0; } }; } // namespace detail template using UnrollFillConstant = detail::UnrollFillConstant<0, N, N == 0>; template using UnrollAssign = detail::UnrollAssign<0, N, N == 0>; template using UnrollVarArgsAssign = detail::UnrollVarArgsAssign; template using UnrollCompare = detail::UnrollCompare<0, N, N == 0>; template using UnrollAdd = detail::UnrollAdd<0, N, N == 0>; template using UnrollMul = detail::UnrollMul<0, N, N == 0>; template using UnrollProduct = detail::UnrollProduct<0, N, N == 0>; } // namespace framework } // namespace paddle