diff --git a/paddle/fluid/operators/argsort_op.cc b/paddle/fluid/operators/argsort_op.cc
index 1995b7ba048bb71b6c1fe357967e7317230825e8..33a1aafa0fd2e8f277840b0b62db18490fff2a7a 100644
--- a/paddle/fluid/operators/argsort_op.cc
+++ b/paddle/fluid/operators/argsort_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/argsort_op.h"
+#include <memory>
 
 namespace paddle {
 namespace operators {
@@ -21,7 +22,7 @@ class ArgsortOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of ArgsortOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -49,6 +50,24 @@ class ArgsortOp : public framework::OperatorWithKernel {
   }
 };
 
+class ArgsortGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
+  }
+};
+
 class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -83,16 +102,42 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
   }
 };
 
+template <typename T>
+class ArgsortGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  std::unique_ptr<T> Apply() const override {
+    std::unique_ptr<T> op(new T());
+    op->SetType("argsort_grad");
+    op->SetInput("Indices", this->Output("Indices"));
+    op->SetInput("X", this->Input("X"));
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
+    return op;
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ArgsortGradNoNeedBufferVarInference, "X");
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(
-    argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
+                  ops::ArgsortGradOpMaker<paddle::framework::OpDesc>,
+                  ops::ArgsortGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(argsort_grad, ops::ArgsortGradOp,
+                  ops::ArgsortGradNoNeedBufferVarInference);
 REGISTER_OP_CPU_KERNEL(argsort,
                        ops::ArgsortKernel<paddle::platform::CPUPlace, float>,
                        ops::ArgsortKernel<paddle::platform::CPUPlace, double>,
                        ops::ArgsortKernel<paddle::platform::CPUPlace, int>,
                        ops::ArgsortKernel<paddle::platform::CPUPlace, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    argsort_grad,
+    ops::ArgsortGradientKernel<paddle::platform::CPUPlace, float>,
+    ops::ArgsortGradientKernel<paddle::platform::CPUPlace, double>,
+    ops::ArgsortGradientKernel<paddle::platform::CPUPlace, int>,
+    ops::ArgsortGradientKernel<paddle::platform::CPUPlace, int64_t>);
diff --git a/paddle/fluid/operators/argsort_op.cu b/paddle/fluid/operators/argsort_op.cu
index 0ea7e3dcb14867002b48e66c322160f8a2c49ba7..006bf559195aa23d08433618d36231aeda228b3d 100644
--- a/paddle/fluid/operators/argsort_op.cu
+++ b/paddle/fluid/operators/argsort_op.cu
@@ -58,6 +58,19 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
   }
 }
 
+template <typename T, typename IndType>
+static __global__ void FillGrad(const T* dO, const IndType* indices, T* dX,
+                                IndType num_rows, IndType num_cols) {
+  int col_id = threadIdx.x;
+  int row_id = blockIdx.x;
+
+  for (IndType j = row_id; j < num_rows; j += gridDim.x) {
+    for (IndType i = col_id; i < num_cols; i += blockDim.x) {
+      dX[j * num_cols + indices[j * num_cols + i]] = dO[j * num_cols + i];
+    }
+  }
+}
+
 // Sort by flag descending, True: descending. False: Ascending.
 // Default is false.
 template <typename T, typename IndType>
@@ -160,6 +173,35 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
                     temp_storage_bytes, cudaGetErrorString(err));
 }
+template <typename T, typename IndType>
+void ArgFullAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO,
+                   const Tensor* indices, Tensor* dX, const IndType num_rows,
+                   const IndType num_cols) {
+  auto cu_stream = ctx.stream();
+
+  auto ComputeBlockSize = [](IndType col) {
+    if (col > 512)
+      return 1024;
+    else if (col > 256 && col <= 512)
+      return 512;
+    else if (col > 128 && col <= 256)
+      return 256;
+    else if (col > 64 && col <= 128)
+      return 128;
+    else
+      return 64;
+  };
+
+  int block_size = ComputeBlockSize(num_cols);
+
+  int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
+  // actually, int num_rows < max_grid_size
+  int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
+  FillGrad<<<grid_size, block_size, 0, cu_stream>>>(
+      dO->data<T>(), indices->data<IndType>(), dX->data<T>(), num_rows,
+      num_cols);
+}
+
 template <typename T>
 class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -234,6 +276,81 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename T>
+class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* indices = ctx.Input<Tensor>("Indices");
+    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    int axis = ctx.Attr<int>("axis");
+
+    dX->mutable_data<T>(ctx.GetPlace());
+    auto dxt = framework::EigenVector<T>::Flatten(*dX);
+    auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
+                       .eigen_device();
+    dxt.device(place) = dxt.constant(static_cast<T>(0));
+    if (dO->numel() == 0) return;
+
+    auto in_dims = indices->dims();
+    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+
+    int64_t numel = indices->numel();
+
+    // Special case for full sort, speedup ~190x.
+    if (axis == -1 || axis + 1 == in_dims.size()) {
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t input_width = in_dims[in_dims.size() - 1];
+      const auto& dev_ctx = ctx.cuda_device_context();
+      ArgFullAssign<T, int64_t>(dev_ctx, dO, indices, dX, input_height,
+                                input_width);
+    } else {
+      // if not full sort, do transpose first
+      std::vector<int> trans;
+      for (int i = 0; i < axis; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(axis);
+      framework::DDim trans_dims(in_dims);
+      for (int i = 0; i < trans.size(); i++) {
+        trans_dims[i] = in_dims[trans[i]];
+      }
+
+      Tensor trans_dO;
+      trans_dO.mutable_data<T>(trans_dims, ctx.GetPlace());
+      Tensor trans_ind;
+      trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+      int ndims = trans.size();
+      const auto& dev_ctx = ctx.cuda_device_context();
+      // Do transpose
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *dO,
+                                                   &trans_dO, trans);
+      TransCompute<platform::CUDADeviceContext, int64_t>(
+          ndims, dev_ctx, *indices, &trans_ind, trans);
+
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
+      const int64_t input_width = trans_dims[trans_dims.size() - 1];
+
+      Tensor tmp_out;
+      tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
+
+      ArgFullAssign<T, int64_t>(dev_ctx, &trans_dO, &trans_ind, &tmp_out,
+                                input_height, input_width);
+
+      // transpose back
+      TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out, dX,
+                                                   trans);
+      return;
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -243,3 +360,9 @@ REGISTER_OP_CUDA_KERNEL(
     paddle::operators::ArgsortOpCUDAKernel<double>,
     paddle::operators::ArgsortOpCUDAKernel<int>,
     paddle::operators::ArgsortOpCUDAKernel<int64_t>,
     paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
+REGISTER_OP_CUDA_KERNEL(
+    argsort_grad, paddle::operators::ArgsortGradOpCUDAKernel<float>,
+    paddle::operators::ArgsortGradOpCUDAKernel<double>,
+    paddle::operators::ArgsortGradOpCUDAKernel<int>,
+    paddle::operators::ArgsortGradOpCUDAKernel<int64_t>,
+    paddle::operators::ArgsortGradOpCUDAKernel<paddle::platform::float16>);
diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h
index c48c4c14a83ec3a3f3386639c089908022e0e02c..fb353a8a2367b4cfe1c3e91495e5d409ae4d3772 100644
--- a/paddle/fluid/operators/argsort_op.h
+++ b/paddle/fluid/operators/argsort_op.h
@@ -68,6 +68,31 @@ static void FullSort(Type input_height, Type input_width, int input_dim,
     }
   }
 }
+
+template <typename T, typename Type>
+static void FullAssign(Type input_height, Type input_width, int input_dim,
+                       const framework::Tensor* input,
+                       const framework::Tensor* indices, T* t_out) {
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (Type i = 0; i < input_height; ++i) {
+    if (input_dim == 1) {
+      auto e_input = EigenVector<T>::Flatten(*input);
+      auto e_indices = EigenVector<Type>::Flatten(*indices);
+      for (Type j = 0; j < input_width; ++j) {
+        t_out[i * input_width + e_indices(j)] = e_input(e_indices(j));
+      }
+    } else {
+      auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
+      auto e_indices = EigenMatrix<Type>::Reshape(*indices, input_dim - 1);
+      for (Type j = 0; j < input_width; ++j) {
+        t_out[i * input_width + e_indices(i, j)] = e_input(i, e_indices(i, j));
+      }
+    }
+  }
+}
+
 template <typename DeviceContext, typename T>
 class ArgsortKernel : public framework::OpKernel<T> {
  public:
@@ -142,5 +167,77 @@ class ArgsortKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class ArgsortGradientKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* indices = ctx.Input<Tensor>("Indices");
+    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    int axis = ctx.Attr<int>("axis");
+
+    auto in_dims = indices->dims();
+    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+
+    dX->mutable_data<T>(ctx.GetPlace());
+    auto dxt = framework::EigenVector<T>::Flatten(*dX);
+    auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
+                       .eigen_device();
+    dxt.device(place) = dxt.constant(static_cast<T>(0));
+    if (dO->numel() == 0) return;
+
+    // Do full assign
+    if (axis == -1 || axis + 1 == in_dims.size()) {
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
+      const int64_t input_width = in_dims[in_dims.size() - 1];
+
+      FullAssign<T, int64_t>(input_height, input_width, in_dims.size(), dO,
+                             indices, dX->data<T>());
+    } else {
+      // If not full assign do transpose
+      std::vector<int> trans;
+      for (int i = 0; i < axis; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(in_dims.size() - 1);
+      for (int i = axis + 1; i < in_dims.size() - 1; i++) {
+        trans.push_back(i);
+      }
+      trans.push_back(axis);
+      framework::DDim trans_dims(in_dims);
+      for (size_t i = 0; i < trans.size(); i++) {
+        trans_dims[i] = in_dims[trans[i]];
+      }
+
+      Tensor trans_dO;
+      trans_dO.mutable_data<T>(trans_dims, ctx.GetPlace());
+      Tensor trans_ind;
+      trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
+      int ndims = trans.size();
+      auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+      // Do transpose
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, *dO,
+                                                  &trans_dO, trans);
+      TransCompute<platform::CPUDeviceContext, int64_t>(
+          ndims, dev_ctx, *indices, &trans_ind, trans);
+
+      const int64_t input_height = framework::product(
+          framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
+      const int64_t input_width = trans_dims[trans_dims.size() - 1];
+
+      Tensor tmp_out;
+      T* t_out = tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
+
+      FullAssign<T, int64_t>(input_height, input_width, in_dims.size(),
+                             &trans_dO, &trans_ind, t_out);
+
+      // transpose back
+      TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, tmp_out, dX,
+                                                  trans);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_argsort_op.py b/python/paddle/fluid/tests/unittests/test_argsort_op.py
old mode 100755
new mode 100644
index 89ff5d7101a9aeaa529acf26f00670d29665ebdf..44cd34879a69fc5df857e8202f57bcd4f312fd87
--- a/python/paddle/fluid/tests/unittests/test_argsort_op.py
+++ b/python/paddle/fluid/tests/unittests/test_argsort_op.py
@@ -48,7 +48,7 @@ class TestArgsortOp(OpTest):
         self.axis = -1
 
     def init_datatype(self):
-        self.dtype = "float32"
+        self.dtype = "float64"
 
     def init_direction(self):
         self.descending = False
@@ -56,6 +56,9 @@ class TestArgsortOp(OpTest):
     def test_check_output(self):
         self.check_output()
 
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
 
 class TestArgsortOpAxis0(TestArgsortOp):
     def init_axis(self):
@@ -146,5 +149,18 @@ class TestArgsortOpDescendingAxisNeg2(TestArgsortOpAxisNeg2):
         self.descending = True
 
 
+class TestArgsortOpFP32Axis(TestArgsortOp):
+    def init_datatype(self):
+        self.dtype = "float32"
+
+
+class TestArgsortOpFP32DescendingAxis(TestArgsortOp):
+    def init_datatype(self):
+        self.dtype = "float32"
+
+    def init_direction(self):
+        self.descending = True
+
+
 if __name__ == "__main__":
     unittest.main()
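
For readers following the kernels above: along the sorted axis the backward pass is a plain scatter, dX[i, Indices[i, j]] = dOut[i, j], which is what the FillGrad CUDA kernel implements row by row (and what the check_grad test above exercises). The following minimal NumPy sketch of that reference rule is illustrative only and not part of the patch; the helper name is hypothetical.

import numpy as np

def argsort_grad_reference(indices, dout):
    """Reference scatter for argsort backward along the last axis (2-D case).

    The gradient arriving at position j of the sorted output is routed back
    to input position indices[i, j], mirroring the FillGrad kernel above.
    """
    dx = np.zeros_like(dout)
    num_rows, num_cols = dout.shape
    for i in range(num_rows):
        for j in range(num_cols):
            dx[i, indices[i, j]] = dout[i, j]
    return dx

# Quick self-check against NumPy's own argsort on random data:
# gathering dx back through the indices must reproduce dout.
x = np.random.rand(3, 5)
ids = np.argsort(x, axis=-1)
dout = np.random.rand(3, 5)
dx = argsort_grad_reference(ids, dout)
assert np.allclose(np.take_along_axis(dx, ids, axis=-1), dout)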