/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/framework/data_type.h" #include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { namespace math { template void SetConstant::operator()(const DeviceContext& context, framework::Tensor* tensor, T num) { auto t = framework::EigenVector::Flatten(*tensor); t.device(*context.eigen_device()) = t.constant(static_cast(num)); } template void Transpose::operator()( const DeviceContext& context, const framework::Tensor& in, framework::Tensor* out, const std::vector& axis) { Eigen::array permute; for (int i = 0; i < Rank; i++) { permute[i] = axis[i]; } auto in_dim = in.dims(); auto out_dim = out->dims(); auto eigen_in = framework::EigenTensor::From(in); auto eigen_out = framework::EigenTensor::From(*out); auto* dev = context.eigen_device(); eigen_out.device(*dev) = eigen_in.shuffle(permute); } template void RowwiseAdd::operator()(const DeviceContext& context, const framework::Tensor& input, const framework::Tensor& vector, framework::Tensor* output) { auto in_dims = input.dims(); auto size = input.numel() / in_dims[0]; PADDLE_ENFORCE_EQ(vector.numel(), size); PADDLE_ENFORCE_EQ(output->dims(), in_dims); auto in = framework::EigenMatrix::From(input); auto vec = framework::EigenMatrix::From(vector); auto out = framework::EigenMatrix::From(*output); Eigen::array shape({{1, static_cast(size)}}); Eigen::array bcast({{static_cast(in_dims[0]), 1}}); out.device(*context.eigen_device()) = in + vec.reshape(shape).broadcast(bcast); } template void ColwiseSum::operator()(const DeviceContext& context, const framework::Tensor& input, framework::Tensor* out) { auto in_dims = input.dims(); auto size = input.numel() / in_dims[0]; PADDLE_ENFORCE_EQ(out->numel(), size); auto in = framework::EigenMatrix::From(input); auto vec = framework::EigenVector::Flatten(*out); vec.device(*context.eigen_device()) = in.sum(Eigen::array({{0}})); } template void RowwiseSum::operator()(const DeviceContext& context, const framework::Tensor& input, framework::Tensor* vector) { auto in_dims = input.dims(); auto size = input.numel() / in_dims[1]; PADDLE_ENFORCE_EQ(vector->numel(), size); auto in = framework::EigenMatrix::From(input); auto vec = framework::EigenMatrix::From(*vector); vec.device(*context.eigen_device()) = in.sum(Eigen::array({{1}})); } // Specialize for CPU, since Eigen implement a general reduce. However, // colwise-sum can be easily implemented. General reduce has a huge overhead in // CPU template class ColwiseSum { public: void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& input, framework::Tensor* out) { auto& in_dims = input.dims(); auto height = in_dims[0]; auto size = in_dims[1]; PADDLE_ENFORCE_EQ(out->numel(), size); T* out_buf = out->mutable_data(out->place()); const T* in_buf = input.data(); for (size_t i = 0; i < height; ++i) { for (size_t j = 0; j < size; ++j) { if (i == 0) { out_buf[j] = in_buf[i * size + j]; } else { out_buf[j] += in_buf[i * size + j]; } } } } }; } // namespace math } // namespace operators } // namespace paddle