/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace math {

using float16 = paddle::platform::float16;

// Explicit GPU instantiations of SetConstant for the supported element types.
template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>;
template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>;

// Explicitly instantiates Transpose on the GPU for one tensor rank and each
// supported element type.
#define DEFINE_GPU_TRANS(RANK)                                             \
  template struct Transpose<platform::CUDADeviceContext, float, RANK>;     \
  template struct Transpose<platform::CUDADeviceContext, double, RANK>;    \
  template struct Transpose<platform::CUDADeviceContext, float16, RANK>;   \
  template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>;    \
  template struct Transpose<platform::CUDADeviceContext, int32_t, RANK>;   \
  template struct Transpose<platform::CUDADeviceContext, int64_t, RANK>;

DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
DEFINE_GPU_TRANS(3);
DEFINE_GPU_TRANS(4);
DEFINE_GPU_TRANS(5);
DEFINE_GPU_TRANS(6);

// Type visitor handed to framework::VisitDataType by
// set_constant_with_place<CUDAPlace>. apply<T>() fills tensor_ with value_
// converted to T using the GPU SetConstant functor. (VisitDataType presumably
// calls apply<T>() for the tensor's runtime element type -- defined
// elsewhere.)
struct TensorSetConstantGPU {
  TensorSetConstantGPU(const platform::DeviceContext& context,
                       framework::Tensor* tensor, float value)
      : context_(context), tensor_(tensor), value_(value) {}

  template <typename T>
  void apply() const {
    SetConstant<platform::CUDADeviceContext, T> functor;
    // The only visible construction site is the CUDAPlace specialization
    // below, so context_ is assumed to be a CUDADeviceContext. static_cast is
    // the proper named cast for a downcast within the DeviceContext hierarchy.
    functor(static_cast<const platform::CUDADeviceContext&>(context_),
            tensor_, static_cast<T>(value_));
  }

  const platform::DeviceContext& context_;
  framework::Tensor* tensor_;
  float value_;
};

// Fills every element of `tensor` with `value` (converted to the tensor's
// element type) on the GPU.
template <>
void set_constant_with_place<platform::CUDAPlace>(
    const platform::DeviceContext& context, framework::Tensor* tensor,
    float value) {
  framework::VisitDataType(tensor->type(),
                           TensorSetConstantGPU(context, tensor, value));
}

// c[i] = a[i] + b[i % width] for every i in [0, num): adds the length-`width`
// vector b to each row of the row-major buffer a (num elements).
//
// The column index uses exact integer arithmetic. Multiplying i by a
// floating-point reciprocal of width and truncating can land on the wrong row
// (e.g. once i is too large to be represented exactly in float), producing a
// column outside [0, width) and an out-of-bounds read of b.
template <typename T>
__global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int width,
                                 int num) {
  // CUDA_KERNEL_LOOP is defined elsewhere; presumably it visits each
  // i in [0, num) across the launched threads.
  CUDA_KERNEL_LOOP(i, num) { c[i] = a[i] + b[i % width]; }
}

// output = input + vector broadcast over rows, where a "row" is everything
// after dim 0 (size = numel / in_dims[0] elements).
template <typename T>
struct RowwiseAdd<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& vector, framework::Tensor* output) {
    auto in_dims = input.dims();
    auto size = input.numel() / in_dims[0];
    PADDLE_ENFORCE_EQ(vector.numel(), size);
    PADDLE_ENFORCE_EQ(output->dims(), in_dims);
    // Nothing to add; also avoids launching a kernel with a zero-sized grid,
    // which CUDA rejects as an invalid configuration.
    if (input.numel() == 0) return;
    int blocks = 512;
    int grids = (input.numel() + blocks - 1) / blocks;
    // Pass the flattened row width `size` (== in_dims[1] for 2-D input) so
    // the kernel agrees with the vector-length check above for any rank.
    RowwiseAddKernel<T><<<grids, blocks, 0, context.stream()>>>(
        input.data<T>(), vector.data<T>(), output->data<T>(),
        static_cast<int>(size), static_cast<int>(input.numel()));
  }
};

template struct RowwiseAdd<platform::CUDADeviceContext, float>;
template struct RowwiseAdd<platform::CUDADeviceContext, double>;
template struct ColwiseSum<platform::CUDADeviceContext, float>;
template struct ColwiseSum<platform::CUDADeviceContext, int>;
template struct ColwiseSum<platform::CUDADeviceContext, int64_t>;
// template struct ColwiseSum<platform::CUDADeviceContext, double>;
// The generic ColwiseSum<platform::CUDADeviceContext, double> failed in debug
// mode, and only for this case, so it is reimplemented here.
//
// vector[j] = sum_i input[i][j], computed as one matrix-vector product of the
// transposed input with a vector of ones.
template <>
void ColwiseSum<platform::CUDADeviceContext, double>::operator()(
    const platform::CUDADeviceContext& context, const framework::Tensor& input,
    framework::Tensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(), size);
  framework::Tensor one;
  one.mutable_data<double>({in_dims[0]}, context.GetPlace());
  SetConstant<platform::CUDADeviceContext, double> set;
  set(context, &one, static_cast<double>(1.0));
  // NOTE(review): assumes Blas::GEMV(trans, M, N, alpha, A, B, beta, C)
  // computes C = alpha * op(A) * B + beta * C with A a row-major M x N
  // matrix; with trans = true this is vector = input^T * ones.
  GetBlas<platform::CUDADeviceContext, double>(context).GEMV(
      true, static_cast<int>(in_dims[0]), static_cast<int>(in_dims[1]), 1.0,
      input.data<double>(), one.data<double>(), 0.0, vector->data<double>());
}

template struct RowwiseSum<platform::CUDADeviceContext, float>;
// template struct RowwiseSum<platform::CUDADeviceContext, double>;
// TODO(zcd): Following ColwiseSum format, need to confirm.
// The generic RowwiseSum<platform::CUDADeviceContext, double> failed in debug
// mode, and only for this case, so it is reimplemented below.
// Hand-written double specialization of RowwiseSum on the GPU:
// vector[i] = sum_j input[i][j], where row i holds size = numel / in_dims[0]
// elements. Computed as a single matrix-vector product input * ones.
template <>
void RowwiseSum<platform::CUDADeviceContext, double>::operator()(
    const platform::CUDADeviceContext& context, const framework::Tensor& input,
    framework::Tensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(), in_dims[0]);
  framework::Tensor one;
  one.mutable_data<double>({size}, context.GetPlace());
  SetConstant<platform::CUDADeviceContext, double> set;
  set(context, &one, static_cast<double>(1.0));
  // NOTE(review): assumes Blas::GEMV(trans, M, N, alpha, A, B, beta, C)
  // computes C = alpha * op(A) * B + beta * C with A a row-major M x N
  // matrix -- the same convention the ColwiseSum<double> specialization in
  // this file relies on. Row sums therefore need no transpose, with
  // A = input (M = in_dims[0] rows, N = size columns), B = ones (length N)
  // and C = vector (length M). Operand order matters: treating `one` as the
  // matrix would read M * N elements from the size-element `one` buffer.
  GetBlas<platform::CUDADeviceContext, double>(context).GEMV(
      false, static_cast<int>(in_dims[0]), static_cast<int>(size), 1.0,
      input.data<double>(), one.data<double>(), 0.0, vector->data<double>());
}

template struct RowwiseMean<platform::CUDADeviceContext, float>;
template struct RowwiseMean<platform::CUDADeviceContext, double>;

// GPU specialization of ElementwiseAddTo: *dst = *dst + src element-wise over
// the flattened tensors, evaluated on the device returned by
// ctx->eigen_device().
template <typename T>
struct ElementwiseAddTo<platform::CUDADeviceContext, T> {
  void operator()(platform::CUDADeviceContext* ctx, const framework::Tensor& src,
                  framework::Tensor* dst) {
    auto in = framework::EigenVector<T>::Flatten(src);
    auto out = framework::EigenVector<T>::Flatten(*dst);
    auto& place = *(ctx->eigen_device());
    out.device(place) = out + in;
  }
};

template struct ElementwiseAddTo<platform::CUDADeviceContext,
                                 platform::float16>;

}  // namespace math
}  // namespace operators
}  // namespace paddle