diff --git a/paddle/fluid/operators/mode_op.cc b/paddle/fluid/operators/mode_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..90e513cb1cd07f907cad106632e0071325ae3f04
--- /dev/null
+++ b/paddle/fluid/operators/mode_op.cc
@@ -0,0 +1,155 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/mode_op.h"
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class ModeOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "mode");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "mode");
+    OP_INOUT_CHECK(ctx->HasOutput("Indices"), "Output", "Indices", "mode");
+
+    auto input_dims = ctx->GetInputDim("X");
+    const int& dim_size = input_dims.size();
+    int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
+    PADDLE_ENFORCE_EQ(
+        (axis < dim_size) && (axis >= (-1 * dim_size)), true,
+        paddle::platform::errors::InvalidArgument(
+            "the axis of ModeOp must be in the range [-%d, %d), but the axis "
+            "given is %d",
+            dim_size, dim_size, axis));
+    PADDLE_ENFORCE_GE(input_dims.size(), 1,
+                      paddle::platform::errors::InvalidArgument(
+                          "the input of ModeOp must have at least 1 dimension"));
+    if (axis < 0) axis += dim_size;
+    bool keepdim = ctx->Attrs().Get<bool>("keepdim");
+    std::vector<int64_t> dimvec;
+    for (int64_t i = 0; i < axis; i++) {
+      dimvec.emplace_back(input_dims[i]);
+    }
+    if (keepdim) {
+      dimvec.emplace_back(static_cast<int64_t>(1));
+    }
+    for (int64_t i = axis + 1; i < dim_size; i++) {
+      dimvec.emplace_back(input_dims[i]);
+    }
+    framework::DDim dims = framework::make_ddim(dimvec);
+    ctx->SetOutputDim("Out", dims);
+    ctx->SetOutputDim("Indices", dims);
+    ctx->ShareLoD("X", "Out");
+    ctx->ShareLoD("X", "Indices");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context(), layout_, library_);
+  }
+};
+
+class ModeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) The input of Mode op");
+    AddOutput("Out", "(Tensor) The output tensor of Mode op");
+    AddOutput("Indices", "(Tensor) The indices of the mode elements of the input");
+    AddAttr<int>("axis",
+                 "the axis along which to compute the mode values. "
+                 "If not set, the mode is computed along the last axis.")
+        .SetDefault(-1);
+    AddAttr<bool>("keepdim", "Keep the reduced dimension in the output.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+This operator finds the mode (the most frequent value) of the input Tensor
+along the given axis, and outputs the mode values together with their indices.
+)DOC");
+  }
+};
+
+class ModeOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("X"), true,
+        platform::errors::InvalidArgument("Input(X) should not be null"));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("Indices"), true,
+        platform::errors::InvalidArgument("Input(Indices) should not be null"));
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
+                      platform::errors::InvalidArgument(
+                          "Grad Input(Out) should not be null"));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput(framework::GradVarName("X")), true,
+        platform::errors::InvalidArgument("Grad Output(X) should not be null"));
+
+    auto x_dims = ctx->GetInputDim("X");
+    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+    return framework::OpKernelType(data_type, ctx.device_context());
+  }
+};
+
+template <typename T>
+class ModeGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("mode_grad");
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Indices", this->Output("Indices"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(mode, ops::ModeOp, ops::ModeOpMaker,
+                  ops::ModeGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ModeGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OP_CPU_KERNEL(mode,
+                       ops::ModeCPUKernel<paddle::platform::CPUPlace, float>,
+                       ops::ModeCPUKernel<paddle::platform::CPUPlace, double>,
+                       ops::ModeCPUKernel<paddle::platform::CPUPlace, int32_t>,
+                       ops::ModeCPUKernel<paddle::platform::CPUPlace, int64_t>);
+
+REGISTER_OPERATOR(mode_grad, ops::ModeOpGrad);
+REGISTER_OP_CPU_KERNEL(
+    mode_grad, ops::ModeGradCPUKernel<paddle::platform::CPUPlace, float>,
+    ops::ModeGradCPUKernel<paddle::platform::CPUPlace, double>,
+    ops::ModeGradCPUKernel<paddle::platform::CPUPlace, int32_t>,
+    ops::ModeGradCPUKernel<paddle::platform::CPUPlace, int64_t>);
diff --git a/paddle/fluid/operators/mode_op.cu b/paddle/fluid/operators/mode_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b42bdb548216e797e806e0bc1e0bacc8e442d320
--- /dev/null
+++ b/paddle/fluid/operators/mode_op.cu
@@ -0,0 +1,233 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <thrust/device_vector.h>
+#include <thrust/execution_policy.h>
+#include <thrust/functional.h>
+#include <thrust/inner_product.h>
+#include <thrust/iterator/constant_iterator.h>
+#include <thrust/sequence.h>
+#include <thrust/sort.h>
+
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/mode_op.h"
+#include "paddle/fluid/operators/top_k_function_cuda.h"
+#include "paddle/fluid/operators/top_k_v2_op.h"
+
+namespace paddle {
+namespace operators {
+
+int ComputeBlockSize(int col) {
+  if (col > 512)
+    return 1024;
+  else if (col > 256 && col <= 512)
+    return 512;
+  else if (col > 128 && col <= 256)
+    return 256;
+  else if (col > 64 && col <= 128)
+    return 128;
+  else
+    return 64;
+}
+
+// Computes the mode of each row of a [num_rows, num_cols] view of the input:
+// every row is sorted, runs of equal values are counted with reduce_by_key,
+// and the value of the longest run is taken as the mode of that row.
+template <typename T>
+void getModebySort(const platform::CUDADeviceContext& ctx,
+                   const framework::Tensor* input_tensor,
+                   const int64_t num_cols, const int64_t num_rows,
+                   T* out_tensor, int64_t* indices_tensor) {
+  framework::Tensor input_tmp;
+  framework::TensorCopy(*input_tensor, ctx.GetPlace(), &input_tmp);
+  T* input_tmp_data = input_tmp.mutable_data<T>(ctx.GetPlace());
+  input_tmp.Resize(framework::make_ddim({num_rows, num_cols}));
+  thrust::device_ptr<T> out_tensor_ptr(out_tensor);
+  thrust::device_ptr<int64_t> indices_tensor_ptr(indices_tensor);
+
+  for (int64_t i = 0; i < num_rows; ++i) {
+    T* begin = input_tmp_data + num_cols * i;
+    T* end = input_tmp_data + num_cols * (i + 1);
+    thrust::device_vector<int64_t> indices_data(num_cols);
+    thrust::sequence(thrust::device, indices_data.begin(),
+                     indices_data.begin() + num_cols);
+    thrust::sort_by_key(thrust::device, begin, end, indices_data.begin());
+    int unique = 1 + thrust::inner_product(thrust::device, begin, end - 1,
+                                           begin + 1, 0, thrust::plus<int>(),
+                                           thrust::not_equal_to<T>());
+    thrust::device_vector<T> keys_data(unique);
+    thrust::device_vector<int64_t> cnts_data(unique);
+    thrust::reduce_by_key(thrust::device, begin, end,
+                          thrust::constant_iterator<int>(1), keys_data.begin(),
+                          cnts_data.begin());
+    auto it = thrust::max_element(thrust::device, cnts_data.begin(),
+                                  cnts_data.begin() + unique);
+    T mode = keys_data[it - cnts_data.begin()];
+    int64_t counts = cnts_data[it - cnts_data.begin()];
+    auto pos = thrust::find(thrust::device, begin, end, mode);
+    int64_t index = indices_data[pos - begin + counts - 1];
+    out_tensor_ptr[i] = static_cast<T>(mode);
+    indices_tensor_ptr[i] = static_cast<int64_t>(index);
+  }
+}
+
+template <typename DeviceContext, typename T>
+class ModeOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(ctx.GetPlace()), true,
+        platform::errors::InvalidArgument(
+            "This kernel must run on CUDAPlace, please check the device "
+            "configuration."));
+    auto* input = ctx.Input<framework::Tensor>("X");
+    auto* output = ctx.Output<framework::Tensor>("Out");
+    auto* indices = ctx.Output<framework::Tensor>("Indices");
+    int axis = static_cast<int>(ctx.Attr<int>("axis"));
+    bool keepdim = static_cast<bool>(ctx.Attr<bool>("keepdim"));
+
+    // get the input dims
+    const auto& in_dims = input->dims();
+    // calculate the real axis
+    if (axis < 0) axis += in_dims.size();
+
+    auto out_dims = output->dims();
+
+    const T* input_data = input->data<T>();
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+    int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
+
+    if (axis == in_dims.size() - 1) {
+      const int64_t& input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t& input_width = in_dims[in_dims.size() - 1];
+      const auto& dev_ctx = ctx.cuda_device_context();
+      getModebySort<T>(dev_ctx, input, input_width, input_height, output_data,
+                       indices_data);
+    } else {
+      std::vector<int> trans_axis;
+      for (int i = 0; i < axis; i++) {
+        trans_axis.emplace_back(i);
+      }
+      trans_axis.emplace_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans_axis.emplace_back(i);
+      }
+      trans_axis.emplace_back(axis);
+
+      if (!keepdim) {
+        std::vector<int64_t> tmp_out_shape;
+        for (int i = 0; i < axis; i++) {
+          tmp_out_shape.emplace_back(in_dims[i]);
+        }
+        tmp_out_shape.emplace_back(1);
+        for (int i = axis + 1; i < in_dims.size(); i++) {
+          tmp_out_shape.emplace_back(in_dims[i]);
+        }
+        framework::DDim tmp_out_dim = framework::make_ddim(tmp_out_shape);
+        output->Resize(tmp_out_dim);
+        indices->Resize(tmp_out_dim);
+      }
+
+      framework::DDim trans_shape(in_dims);
+      framework::DDim trans_out_shape(in_dims);
+      for (int i = 0; i < trans_axis.size(); i++) {
+        trans_shape[i] = in_dims[trans_axis[i]];
+        trans_out_shape[i] = in_dims[trans_axis[i]];
+      }
+      trans_out_shape[in_dims.size() - 1] = 1;
+
+      // second step, transpose the input
+      framework::Tensor trans_input;
+      trans_input.mutable_data<T>(trans_shape, ctx.GetPlace());
+      int ndims = trans_axis.size();
+      const auto& dev_ctx = ctx.cuda_device_context();
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
+                                                   &trans_input, trans_axis);
+      framework::Tensor trans_ind;
+      int64_t* trans_ind_data =
+          trans_ind.mutable_data<int64_t>(trans_out_shape, ctx.GetPlace());
+      framework::Tensor trans_out;
+      T* trans_out_data =
+          trans_out.mutable_data<T>(trans_out_shape, ctx.GetPlace());
+
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_shape, 0, trans_shape.size() - 1));
+      const int64_t input_width = trans_shape[trans_shape.size() - 1];
+      getModebySort<T>(dev_ctx, &trans_input, input_width, input_height,
+                       trans_out_data, trans_ind_data);
+      // last step, transpose back the indices and output
+      TransCompute<platform::CUDADeviceContext, int64_t>(
+          ndims, dev_ctx, trans_ind, indices, trans_axis);
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, trans_out,
+                                                   output, trans_axis);
+      if (!keepdim) {
+        output->Resize(out_dims);
+        indices->Resize(out_dims);
+      }
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ModeOpGradCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(context.GetPlace()), true,
+        platform::errors::InvalidArgument(
+            "This kernel must run on CUDAPlace, please check the device "
+            "configuration."));
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* out_grad =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* indices = context.Input<framework::Tensor>("Indices");
+    auto* x_grad =
+        context.Output<framework::Tensor>(framework::GradVarName("X"));
+    int axis = context.Attr<int>("axis");
+
+    const auto& in_dims = x->dims();
+    auto out_dims = indices->dims();
+
+    if (axis < 0) axis += in_dims.size();
+    // allocate the cuda memory for the x_grad
+    T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
+    const T* out_grad_data = out_grad->data<T>();
+    const int64_t* indices_data = indices->data<int64_t>();
+
+    int pre, n, post;
+    GetDims(in_dims, axis, &pre, &n, &post);
+
+    // calculate the block and grid num
+    auto& dev_ctx = context.cuda_device_context();
+    int block_size = ComputeBlockSize(post);
+    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
+    const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1);
+    int grid_size = std::min(max_blocks, pre);
+    AssignGradWithAxis<T><<<grid_size, block_size, 64 * 4, dev_ctx.stream()>>>(
+        out_grad_data, indices_data, x_grad_data, pre, post, n, 1);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    mode, ops::ModeOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ModeOpCUDAKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ModeOpCUDAKernel<paddle::platform::CUDADeviceContext, int32_t>,
+    ops::ModeOpCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    mode_grad,
+    ops::ModeOpGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ModeOpGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::ModeOpGradCUDAKernel<paddle::platform::CUDADeviceContext, int32_t>,
+    ops::ModeOpGradCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/mode_op.h b/paddle/fluid/operators/mode_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..dac0ff9279c09b5b1e623ae8969155ab3ceb2b8b
--- /dev/null
+++ b/paddle/fluid/operators/mode_op.h
@@ -0,0 +1,317 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <iostream>
+#include <utility>
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/transpose_op.h"
+
+namespace paddle {
+namespace operators {
+
+// CPU helper: for each of the input_height rows, sort (value, index) pairs by
+// value and take the value of the longest run of equal values as the mode.
+template <typename T, typename Type>
+static void getMode(Type input_height, Type input_width, int input_dim,
+                    const framework::Tensor* input, T* t_out,
+                    Type* t_indices) {
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (Type i = 0; i < input_height; ++i) {
+    std::vector<std::pair<T, Type>> col_vec;
+    col_vec.reserve(input_width);
+    if (input_dim == 1) {
+      auto e_input = framework::EigenVector<T>::Flatten(*input);
+      for (Type j = 0; j < input_width; ++j) {
+        col_vec.emplace_back(std::pair<T, Type>(e_input(j), j));
+      }
+    } else {
+      auto e_input = framework::EigenMatrix<T>::Reshape(*input, input_dim - 1);
+      for (Type j = 0; j < input_width; ++j) {
+        col_vec.emplace_back(std::pair<T, Type>(e_input(i, j), j));
+      }
+    }
+    std::sort(col_vec.begin(), col_vec.end(),
+              [](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
+                return (!std::isnan(static_cast<double>(l.first)) &&
+                        std::isnan(static_cast<double>(r.first))) ||
+                       (l.first < r.first);
+              });
+    T mode = 0;
+    int64_t indice = 0;
+    int64_t cur_freq = 0;
+    int64_t max_freq = 0;
+    for (int64_t i = 0; i < input_width; ++i) {
+      ++cur_freq;
+      if (i == input_width - 1 || (col_vec[i + 1].first != col_vec[i].first)) {
+        if (cur_freq > max_freq) {
+          max_freq = cur_freq;
+          mode = col_vec[i].first;
+          indice = col_vec[i].second;
+        }
+        cur_freq = 0;
+      }
+    }
+    t_out[i] = mode;
+    t_indices[i] = indice;
+  }
+}
+
+// Scatters the per-row gradient back to the position recorded in `indices`.
+template <typename T, typename Type>
+static void ModeAssign(const Type& input_height, const Type& input_width,
+                       const int& input_dim, const framework::Tensor* input,
+                       const framework::Tensor* indices, T* output_data) {
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (Type i = 0; i < input_height; ++i) {
+    if (input_dim == 1) {
+      auto e_input = framework::EigenVector<T>::Flatten(*input);
+      auto e_indices = framework::EigenVector<Type>::Flatten(*indices);
+      output_data[i * input_width + e_indices(0)] = e_input(0);
+    } else {
+      auto e_input = framework::EigenMatrix<T>::Reshape(*input, input_dim - 1);
+      auto e_indices =
+          framework::EigenMatrix<Type>::Reshape(*indices, input_dim - 1);
+      output_data[i * input_width + e_indices(i, 0)] = e_input(i, 0);
+    }
+  }
+}
+
+template <typename DeviceContext, typename T>
+class ModeCPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<framework::Tensor>("X");
+    auto* output = context.Output<framework::Tensor>("Out");
+    auto* indices = context.Output<framework::Tensor>("Indices");
+    const auto& in_dims = input->dims();
+    bool keepdim = static_cast<bool>(context.Attr<bool>("keepdim"));
+
+    // axis < 0, calculate the real axis
+    int axis = static_cast<int>(context.Attr<int>("axis"));
+    if (axis < 0) axis += in_dims.size();
+
+    T* output_data = output->mutable_data<T>(context.GetPlace());
+    int64_t* indices_data = indices->mutable_data<int64_t>(context.GetPlace());
+    auto out_dims = output->dims();
+    // If axis is not the last dim, transpose it to the last dim, do the
+    // calculation, then transpose it back to the original axis.
+    if (axis == in_dims.size() - 1) {
+      const int64_t& input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t& input_width = in_dims[in_dims.size() - 1];
+      getMode<T, int64_t>(input_height, input_width, in_dims.size(), input,
+                          output_data, indices_data);
+    } else {
+      std::vector<int> trans_axis;
+      for (int i = 0; i < axis; i++) {
+        trans_axis.emplace_back(i);
+      }
+      trans_axis.push_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans_axis.emplace_back(i);
+      }
+      trans_axis.emplace_back(axis);
+
+      if (!keepdim) {
+        std::vector<int64_t> tmp_out_shape;
+        for (int i = 0; i < axis; i++) {
+          tmp_out_shape.emplace_back(in_dims[i]);
+        }
+        tmp_out_shape.emplace_back(1);
+        for (int i = axis + 1; i < in_dims.size(); i++) {
+          tmp_out_shape.emplace_back(in_dims[i]);
+        }
+        framework::DDim tmp_out_dim = framework::make_ddim(tmp_out_shape);
+        output->Resize(tmp_out_dim);
+        indices->Resize(tmp_out_dim);
+      }
+
+      // get the trans input_dims, out_dims
+      framework::DDim trans_shape(in_dims);
+      framework::DDim trans_out_shape(in_dims);
+
+      for (size_t i = 0; i < trans_axis.size(); i++) {
+        trans_shape[i] = in_dims[trans_axis[i]];
+        trans_out_shape[i] = in_dims[trans_axis[i]];
+      }
+      trans_out_shape[in_dims.size() - 1] = 1;
+
+      framework::Tensor trans_input;
+      trans_input.mutable_data<T>(trans_shape, context.GetPlace());
+      int ndims = trans_axis.size();
+      auto& dev_context =
+          context.template device_context<platform::CPUDeviceContext>();
+
+      // transpose the input value
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, *input,
+                                                  &trans_input, trans_axis);
+
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_shape, 0, trans_shape.size() - 1));
+      const int64_t input_width = trans_shape[trans_shape.size() - 1];
+      framework::Tensor tmp_out;
+      T* t_out = tmp_out.mutable_data<T>(trans_out_shape, context.GetPlace());
+      framework::Tensor tmp_indices;
+      auto* t_ind = tmp_indices.mutable_data<int64_t>(trans_out_shape,
+                                                      context.GetPlace());
+
+      getMode<T, int64_t>(input_height, input_width, in_dims.size(),
+                          &trans_input, t_out, t_ind);
+      // transpose back
+      TransCompute<platform::CPUDeviceContext, int64_t>(
+          ndims, dev_context, tmp_indices, indices, trans_axis);
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, tmp_out,
+                                                  output, trans_axis);
+      if (!keepdim) {
+        output->Resize(out_dims);
+        indices->Resize(out_dims);
+      }
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ModeGradCPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* out_grad =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* indices = context.Input<framework::Tensor>("Indices");
+    auto* x_grad =
+        context.Output<framework::Tensor>(framework::GradVarName("X"));
+    int axis = static_cast<int>(context.Attr<int>("axis"));
+    bool keepdim = static_cast<bool>(context.Attr<bool>("keepdim"));
+
+    auto in_dims = x->dims();
+    auto out_dims = indices->dims();
+
+    // axis < 0, get the real axis
+    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+
+    if (!keepdim) {
+      std::vector<int64_t> tmp_out_shape;
+      for (int i = 0; i < axis; i++) {
+        tmp_out_shape.emplace_back(out_dims[i]);
+      }
+      tmp_out_shape.emplace_back(1);
+      for (int i = axis + 1; i < in_dims.size(); i++) {
+        tmp_out_shape.emplace_back(out_dims[i - 1]);
+      }
+      out_dims = framework::make_ddim(tmp_out_shape);
+    }
+    T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
+    if (axis == in_dims.size() - 1) {
+      // allocate the memory for the input_grad
+      // assign the out_grad to input_grad directly
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t input_width = in_dims[in_dims.size() - 1];
+
+      // init the output grad with 0, because some input elements have no grad
+      memset(x_grad_data, 0, x_grad->numel() * sizeof(T));
+      // Assign the output_grad to input_grad
+      if (keepdim) {
+        ModeAssign(input_height, input_width, in_dims.size(), out_grad, indices,
+                   x_grad_data);
+      } else {
+        auto& dev_context =
+            context.template device_context<platform::CPUDeviceContext>();
+        framework::Tensor out_grad_tmp;
+        framework::Tensor indices_tmp;
+        out_grad_tmp.mutable_data<T>(out_grad->dims(), dev_context.GetPlace());
+        indices_tmp.mutable_data<int64_t>(indices->dims(),
+                                          dev_context.GetPlace());
+        framework::TensorCopy(*out_grad, dev_context.GetPlace(), dev_context,
+                              &out_grad_tmp);
+        framework::TensorCopy(*indices, dev_context.GetPlace(), dev_context,
+                              &indices_tmp);
+        out_grad_tmp.Resize(out_dims);
+        indices_tmp.Resize(out_dims);
+        ModeAssign(input_height, input_width, in_dims.size(), &out_grad_tmp,
+                   &indices_tmp, x_grad_data);
+      }
+    } else {
+      // can not assign grad to input_grad directly, must do the transpose
+      std::vector<int> trans_axis;
+      for (int i = 0; i < axis; i++) {
+        trans_axis.emplace_back(i);
+      }
+      trans_axis.emplace_back(out_dims.size() - 1);
+      for (int i = axis + 1; i < out_dims.size() - 1; i++) {
+        trans_axis.emplace_back(i);
+      }
+      trans_axis.emplace_back(axis);
+      framework::DDim trans_shape(out_dims);
+      framework::DDim trans_in_shape(in_dims);
+      for (size_t i = 0; i < trans_axis.size(); i++) {
+        trans_shape[i] = out_dims[trans_axis[i]];
+        trans_in_shape[i] = in_dims[trans_axis[i]];
+      }
+      // transpose the out_grad, indices
+      framework::Tensor trans_dO;
+      trans_dO.mutable_data<T>(trans_shape, context.GetPlace());
+      framework::Tensor trans_ind;
+      trans_ind.mutable_data<int64_t>(trans_shape, context.GetPlace());
+      int ndims = trans_axis.size();
+      auto& dev_context =
+          context.template device_context<platform::CPUDeviceContext>();
+
+      if (keepdim) {
+        // Do transpose
+        TransCompute<platform::CPUDeviceContext, T>(
+            ndims, dev_context, *out_grad, &trans_dO, trans_axis);
+        TransCompute<platform::CPUDeviceContext, int64_t>(
+            ndims, dev_context, *indices, &trans_ind, trans_axis);
+      } else {
+        framework::Tensor out_grad_tmp;
+        framework::Tensor indices_tmp;
+        out_grad_tmp.mutable_data<T>(out_grad->dims(), dev_context.GetPlace());
+        indices_tmp.mutable_data<int64_t>(indices->dims(),
+                                          dev_context.GetPlace());
+        framework::TensorCopy(*out_grad, dev_context.GetPlace(), dev_context,
+                              &out_grad_tmp);
+        framework::TensorCopy(*indices, dev_context.GetPlace(), dev_context,
+                              &indices_tmp);
+        out_grad_tmp.Resize(out_dims);
+        indices_tmp.Resize(out_dims);
+        // Do transpose
+        TransCompute<platform::CPUDeviceContext, T>(
+            ndims, dev_context, out_grad_tmp, &trans_dO, trans_axis);
+        TransCompute<platform::CPUDeviceContext, int64_t>(
+            ndims, dev_context, indices_tmp, &trans_ind, trans_axis);
+      }
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_in_shape, 0, trans_in_shape.size() - 1));
+      const int64_t input_width = trans_in_shape[trans_in_shape.size() - 1];
+
+      // Assign the out_grad to the transposed input_grad
+      framework::Tensor tmp_out;
+      T* t_out = tmp_out.mutable_data<T>(trans_in_shape, context.GetPlace());
+      memset(t_out, 0, x_grad->numel() * sizeof(T));
+
+      ModeAssign(input_height, input_width, in_dims.size(), &trans_dO,
+                 &trans_ind, t_out);
+
+      // Transpose back
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, tmp_out,
+                                                  x_grad, trans_axis);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index b2effed3c9c11731549a9ffbe8772a18112b64d9..67c514b56b901dc67e3ac2ada321b335b6e9255b 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -274,6 +274,7 @@ from .tensor.search import where  # noqa: F401
 from .tensor.search import index_select  # noqa: F401
 from .tensor.search import nonzero  # noqa: F401
 from .tensor.search import sort  # noqa: F401
+from .tensor.search import mode  # noqa: F401
 
 from .tensor.to_string import set_printoptions  # noqa: F401
 
@@ -400,6 +401,7 @@ __all__ = [  # noqa
     'cos',
     'tan',
     'mean',
+    'mode',
     'mv',
     'in_dynamic_mode',
     'min',
diff --git a/python/paddle/fluid/tests/unittests/test_mode_op.py b/python/paddle/fluid/tests/unittests/test_mode_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b0458f2e255fd981aae81b7da88eecd8f4598d3
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_mode_op.py
@@ -0,0 +1,178 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+
+
+def _mode1D(a):
+    sorted_inds = np.argsort(a, kind='stable')
+    sorted_array = a[sorted_inds]
+    max_freq = 0
+    cur_freq = 0
+    mode = -1
+    for i in range(len(sorted_array)):
+        cur_freq += 1
+        if i == len(sorted_array) - 1 or sorted_array[i] != sorted_array[i + 1]:
+            if cur_freq > max_freq:
+                mode = sorted_array[i]
+                index = sorted_inds[i]
+                max_freq = cur_freq
+            cur_freq = 0
+    return mode, index
+
+
+def cal_mode(a, axis, keepdim=False):
+    if axis < 0:
+        axis = len(a.shape) + axis
+    in_dims = list(range(a.ndim))
+    a_view = np.transpose(a, in_dims[:axis] + in_dims[axis + 1:] + [axis])
+    inds = np.ndindex(a_view.shape[:-1])
+    modes = np.empty(a_view.shape[:-1], dtype=a.dtype)
+    indexes = np.empty(a_view.shape[:-1], dtype=np.int64)
+    for ind in inds:
+        modes[ind], indexes[ind] = _mode1D(a_view[ind])
+    if keepdim:
+        newshape = list(a.shape)
+        newshape[axis] = 1
+        modes = modes.reshape(newshape)
+        indexes = indexes.reshape(newshape)
+    return modes, indexes
+
+
+class TestModeOp(OpTest):
+    def init_args(self):
+        self.axis = 1
+
+    def setUp(self):
+        self.op_type = "mode"
+        self.dtype = np.float64
+        np.random.seed(666)
+        self.input_data = np.random.rand(2, 64, 1)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'axis': self.axis}
+        output, indices = cal_mode(self.input_data, axis=self.axis)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+    def test_check_output(self):
+        paddle.enable_static()
+        self.check_output()
+
+    def test_check_grad(self):
+        paddle.enable_static()
+        self.check_grad(set(['X']), 'Out')
+
+
+class TestModeOpLastdim(OpTest):
+    def init_args(self):
+        self.axis = -1
+
+    def setUp(self):
+        self.op_type = "mode"
+        self.dtype = np.float64
+        np.random.seed(666)
+        self.input_data = np.random.rand(2, 1, 1, 2, 30)
+        self.init_args()
+        self.inputs = {'X': self.input_data}
+        self.attrs = {'axis': self.axis}
+        output, indices = cal_mode(self.input_data, axis=self.axis)
+        self.outputs = {'Out': output, 'Indices': indices}
+
+    def test_check_output(self):
+        paddle.enable_static()
+        self.check_output()
+
+    def test_check_grad(self):
+        paddle.enable_static()
+        self.check_grad(set(['X']), 'Out')
+
+
+class TestModeOpKernels(unittest.TestCase):
+    def setUp(self):
+        self.axises = [-1, 1]
+        np.random.seed(666)
+        self.inputs = np.ceil(np.random.rand(2, 10, 10) * 1000)
+
+    def test_mode_op(self):
+        def test_cpu_kernel():
+            paddle.set_device('cpu')
+            tensor = paddle.to_tensor(self.inputs)
+            for axis in self.axises:
+                value_expect, indice_expect = cal_mode(self.inputs, axis)
+                v, inds = paddle.mode(tensor, axis)
+                self.assertTrue(np.allclose(v.numpy(), value_expect))
+
+                value_expect, indice_expect = cal_mode(
+                    self.inputs, axis, keepdim=True)
+                v, inds = paddle.mode(tensor, axis, keepdim=True)
+                self.assertTrue(np.allclose(v.numpy(), value_expect))
+
+        def test_gpu_kernel():
+            paddle.set_device('gpu')
+            tensor = paddle.to_tensor(self.inputs)
+            for axis in self.axises:
+                value_expect, indice_expect = cal_mode(self.inputs, axis)
+                v, inds = paddle.mode(tensor, axis)
+                self.assertTrue(np.allclose(v.numpy(), value_expect))
+
+                value_expect, indice_expect = cal_mode(
+                    self.inputs, axis, keepdim=True)
+                v, inds = paddle.mode(tensor, axis, keepdim=True)
+                self.assertTrue(np.allclose(v.numpy(), value_expect))
+
+        paddle.disable_static()
+        test_cpu_kernel()
+        if fluid.core.is_compiled_with_cuda():
+            test_gpu_kernel()
+
+
+class TestModeOpErrors(unittest.TestCase):
+    def setUp(self):
+        self.x = paddle.uniform([2, 10, 20, 25], dtype='float32')
+
+        def test_dim_range_error():
+            self.x.mode(axis=5)
+
+        self.assertRaises(ValueError, test_dim_range_error)
+
+
+class TestModeOpInStatic(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(666)
+        self.input_data = np.ceil(
+            np.random.random((2, 10, 10)) * 1000, dtype=np.float64)
+
+    def test_run_static(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
+            input_tensor = paddle.static.data(
+                name="x", shape=[2, 10, 10], dtype="float64")
+
+            result = paddle.mode(input_tensor, axis=1)
+            expect_value = cal_mode(self.input_data, axis=1)[0]
+            exe = paddle.static.Executor(paddle.CPUPlace())
+            paddle_result = exe.run(feed={"x": self.input_data},
+                                    fetch_list=[result])[0]
+            self.assertTrue(np.allclose(paddle_result, expect_value))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 11a29eadf9011fae6499f607ab0a738d38f1d94a..ce70181096c7eaf09ab4557651ed3c997ce0b802 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -247,6 +247,8 @@ from .search import nonzero  # noqa: F401
 from .search import sort  # noqa: F401
 from .search import index_sample  # noqa: F401
 from .search import masked_select  # noqa: F401
+from .search import mode  # noqa: F401
+
 from .stat import mean  # noqa: F401
 from .stat import std  # noqa: F401
 from .stat import var  # noqa: F401
@@ -462,6 +464,7 @@ tensor_method_func = [  #noqa
     'gcd',
     'lcm',
     'diff',
+    'mode',
     'lerp',
     'lerp_',
     'erfinv',
diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py
index f3587aa48ddcba4492d6cca310f6ada2b722da2b..afb8a08665263352af2645b6d21356e62a9b3ccb 100644
--- a/python/paddle/tensor/search.py
+++ b/python/paddle/tensor/search.py
@@ -470,6 +470,59 @@ def sort(x, axis=-1, descending=False, name=None):
     return out
 
 
+def mode(x, axis=-1, keepdim=False, name=None):
+    """
+    Finds the values and indices of the modes of the input tensor along the given axis.
+
+    Args:
+        x(Tensor): Tensor, an input N-D Tensor with type float32, float64, int32, int64.
+        axis(int, optional): Axis to compute indices along. The effective range
+            is [-R, R), where R is x.ndim. When axis < 0, it works the same way
+            as axis + R. Default is -1.
+        keepdim(bool, optional): Whether to keep the given axis in the output. If it is True, the output
+            has the same number of dimensions as the input x, with size one along the reduced axis.
+            Otherwise the output has one dimension fewer than x, since the axis is squeezed. Default is False.
+        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        tuple(Tensor), return the values and indices. The value data type is the same as the input `x`. The indices data type is int64.
+
+    Examples:
+
+        .. code-block:: python
+
+           import paddle
+
+           tensor = paddle.to_tensor([[[1,2,2],[2,3,3]],[[0,5,5],[9,9,0]]], dtype=paddle.float32)
+           res = paddle.mode(tensor, 2)
+           print(res)
+           # (Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+           #        [[2., 3.],
+           #         [5., 9.]]), Tensor(shape=[2, 2], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
+           #        [[1, 1],
+           #         [1, 0]]))
+
+    """
+    if in_dygraph_mode():
+        return _C_ops.mode(x, "axis", axis, "keepdim", keepdim)
+
+    helper = LayerHelper("mode", **locals())
+    inputs = {"X": [x]}
+    attrs = {}
+    attrs['axis'] = axis
+    attrs['keepdim'] = keepdim
+
+    values = helper.create_variable_for_type_inference(dtype=x.dtype)
+    indices = helper.create_variable_for_type_inference(dtype="int64")
+
+    helper.append_op(
+        type="mode",
+        inputs=inputs,
+        outputs={"Out": [values],
+                 "Indices": [indices]},
+        attrs=attrs)
+    indices.stop_gradient = True
+    return values, indices
+
+
 def where(condition, x, y, name=None):
     r"""
     Return a tensor of elements selected from either $x$ or $y$, depending on $condition$.
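
Reviewer note: both the CPU and CUDA kernels reduce the general case to the last-axis case by transposing the target axis to the end, computing a per-row mode, and transposing back. Below is a minimal NumPy sketch of those semantics, mirroring the `cal_mode`/`_mode1D` reference used in the unit tests (the helper names `np_mode` and `np_mode_1d` are illustrative and not part of this PR; the exact tie-breaking of the device kernel may differ):

```python
import numpy as np


def np_mode_1d(row):
    # Most frequent value; on count ties take the smallest value, and return
    # the index of its last occurrence, as the test reference _mode1D does.
    values, counts = np.unique(row, return_counts=True)
    value = values[np.argmax(counts)]
    return value, int(np.flatnonzero(row == value)[-1])


def np_mode(a, axis=-1, keepdim=False):
    axis = axis % a.ndim
    moved = np.moveaxis(a, axis, -1)           # move the reduced axis to the end
    flat = moved.reshape(-1, moved.shape[-1])  # one row per reduced slice
    vals = np.empty(flat.shape[0], dtype=a.dtype)
    inds = np.empty(flat.shape[0], dtype=np.int64)
    for i, row in enumerate(flat):
        vals[i], inds[i] = np_mode_1d(row)
    out_shape = list(moved.shape[:-1])
    if keepdim:
        out_shape.insert(axis, 1)
    return vals.reshape(out_shape), inds.reshape(out_shape)


# Example: mode along axis 1 of a (2, 3, 4) array.
x = np.random.randint(0, 3, size=(2, 3, 4)).astype('float64')
values, indices = np_mode(x, axis=1, keepdim=True)
print(values.shape, indices.shape)  # (2, 1, 4) (2, 1, 4)
```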