未验证 提交 10171806 编写于 作者: L limingshu 提交者: GitHub

Support Mod in elementwise system (#33052)

上级 5f198a6e
......@@ -12,13 +12,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mod_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T, typename Enable = void>
struct CudaModFunctor {
inline HOSTDEVICE T operator()(const T* args) const {
T res = args[0] % args[1];
// Accoding to #PR26732: in dividen % divsor
// remainder shall have the same sign as divsor.
if ((res != 0) && ((args[1] ^ res) < 0)) res += args[1];
return res;
}
};
template <typename T>
struct CudaModFunctor<
T, typename std::enable_if_t<std::is_floating_point<T>::value>> {
inline HOSTDEVICE T operator()(const T* args) const {
T res = fmod(args[0], args[1]);
// Accoding to #PR26732: in dividen % divsor
// remainder shall have the same sign as divsor.
if ((res != 0) && ((res < 0) != (args[1] < 0))) res += args[1];
return res;
}
};
template <typename T>
class ElementwiseModKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
cuda_ctx, ins, &outs, axis, CudaModFunctor<T>());
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
elementwise_mod, ops::ElementwiseModKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseModKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseModFPKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseModFPKernel<plat::CUDADeviceContext, double>);
ops::ElementwiseModKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseModKernel<plat::CUDADeviceContext, double>);
......@@ -16,7 +16,6 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册