From 8f035fb63736cba3541d559e416ac062ed57e459 Mon Sep 17 00:00:00 2001 From: Jiawei Wang Date: Mon, 17 Feb 2020 09:52:06 +0800 Subject: [PATCH] Add TopK Op Grad CPU&GPU Kernel test=develop (#22628) * Add TopK Op Grad CPU&GPU Kernel test=develop * Add TopK Op Grad, modify grad op maker test=develop * Add TopK Op Grad, modify grad op maker test=develop * Add TopK Op Grad, modify PADDLE_ENFORCE test=develop * Add TopK Op Grad, modify PADDLE_THROW test=develop * Add TopK Op Grad, modify unittest test=develop * fix ngraph top k op unittest test=develop --- paddle/fluid/operators/top_k_op.cc | 71 +++++++++--- paddle/fluid/operators/top_k_op.cu | 83 ++++++++++++-- paddle/fluid/operators/top_k_op.h | 30 ++++++ .../unittests/ngraph/test_top_k_ngraph_op.py | 2 +- .../fluid/tests/unittests/test_top_k_op.py | 101 +----------------- 5 files changed, 171 insertions(+), 116 deletions(-) diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index c18ec5d418..91184c1ed0 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/top_k_op.h" +#include namespace paddle { namespace operators { @@ -42,11 +43,6 @@ class TopkOp : public framework::OperatorWithKernel { framework::DDim dims = input_dims; dims[dims.size() - 1] = k; - // If has K as tensor, set k=-1 as not know real size at this time. - if (ctx->HasInput("K")) { - dims[dims.size() - 1] = -1; - } - ctx->SetOutputDim("Out", dims); ctx->SetOutputDim("Indices", dims); ctx->ShareLoD("X", "Out"); @@ -89,16 +85,67 @@ For matrices, this operator computes the top k entries in each row. )DOC"); } }; +class TopkOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput("X"), true, + platform::errors::InvalidArgument("Input(X) should be not null")); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Indices"), true, + platform::errors::InvalidArgument("Input(Indices) should be not null")); + PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true, + platform::errors::InvalidArgument( + "Grad Input(Out) should be not null")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput(framework::GradVarName("X")), true, + platform::errors::InvalidArgument("Grad Output(X) should be not null")); + + auto x_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +template +class TopkGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new T()); + op->SetType("top_k_grad"); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetInput("X", this->Input("X")); + op->SetInput("Indices", this->Output("Indices")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + return op; + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR( - top_k, ops::TopkOp, ops::TopkOpMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); +REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker, + ops::TopkGradOpMaker, + ops::TopkGradOpMaker); + +REGISTER_OPERATOR(top_k_grad, ops::TopkOpGrad); + REGISTER_OP_CPU_KERNEL(top_k, ops::TopkKernel, - ops::TopkKernel, - ops::TopkKernel, - ops::TopkKernel); + ops::TopkKernel); + +REGISTER_OP_CPU_KERNEL(top_k_grad, + ops::TopkGradKernel, + ops::TopkGradKernel); diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index c56716f8ed..82ecc2887b 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "cub/cub.cuh" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/float16.h" - // set cub base traits in order to handle float16 namespace cub { template <> @@ -300,6 +300,20 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, } } +template +__global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad, + size_t rows, size_t cols, size_t k) { + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + x_grad[i * cols + j] = 0; + } + for (size_t j = 0; j < k; ++j) { + size_t idx = indices[i * k + j]; + x_grad[i * cols + idx] = out_grad[i * k + j]; + } + } +} + inline static int GetDesiredBlockDim(int dim) { if (dim > 128) { return 256; @@ -478,7 +492,7 @@ bool SortTopk(const platform::CUDADeviceContext& ctx, FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) -template +template class TopkOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -540,6 +554,42 @@ class TopkOpCUDAKernel : public framework::OpKernel { } }; +template +class TopkOpGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(context.GetPlace()), true, + platform::errors::InvalidArgument("It must use CUDAPlace.")); + auto* x = context.Input("X"); + auto* out_grad = context.Input(framework::GradVarName("Out")); + auto* indices = context.Input("Indices"); + auto* x_grad = context.Output(framework::GradVarName("X")); + + T* x_grad_data = x_grad->mutable_data(context.GetPlace()); + const T* out_grad_data = out_grad->data(); + const int64_t* indices_data = indices->data(); + size_t k = indices->dims()[indices->dims().size() - 1]; + + framework::DDim xdims = x->dims(); + const size_t row = + framework::product(framework::slice_ddim(xdims, 0, xdims.size() - 1)); + const size_t col = xdims[xdims.size() - 1]; + const auto& dev_ctx = context.cuda_device_context(); + + const int kMaxHeight = 2048; + int gridx = row < kMaxHeight ? row : kMaxHeight; + switch (GetDesiredBlockDim(col)) { + FIXED_BLOCK_DIM( + AssignGrad<<>>( + x_grad_data, indices_data, out_grad_data, row, col, k)); + default: + PADDLE_THROW( + platform::errors::Unavailable("Error occurs when Assign Grad.")); + } + } +}; #undef FIXED_BLOCK_DIM_BASE #undef FIXED_BLOCK_DIM @@ -547,8 +597,27 @@ class TopkOpCUDAKernel : public framework::OpKernel { } // namespace paddle REGISTER_OP_CUDA_KERNEL( - top_k, paddle::operators::TopkOpCUDAKernel, - paddle::operators::TopkOpCUDAKernel, - paddle::operators::TopkOpCUDAKernel, - paddle::operators::TopkOpCUDAKernel, - paddle::operators::TopkOpCUDAKernel); + top_k, + paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel); + +REGISTER_OP_CUDA_KERNEL( + top_k_grad, + paddle::operators::TopkOpGradCUDAKernel, + paddle::operators::TopkOpGradCUDAKernel, + paddle::operators::TopkOpGradCUDAKernel, + paddle::operators::TopkOpGradCUDAKernel, + paddle::operators::TopkOpGradCUDAKernel); diff --git a/paddle/fluid/operators/top_k_op.h b/paddle/fluid/operators/top_k_op.h index 6b9db26060..1ba01d93ac 100644 --- a/paddle/fluid/operators/top_k_op.h +++ b/paddle/fluid/operators/top_k_op.h @@ -94,5 +94,35 @@ class TopkKernel : public framework::OpKernel { } }; +template +class TopkGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out_grad = context.Input(framework::GradVarName("Out")); + auto* indices = context.Input("Indices"); + auto* x_grad = context.Output(framework::GradVarName("X")); + + T* x_grad_data = x_grad->mutable_data(context.GetPlace()); + const T* out_grad_data = out_grad->data(); + const int64_t* indices_data = indices->data(); + size_t k = indices->dims()[indices->dims().size() - 1]; + + framework::DDim xdims = x->dims(); + const size_t row = + framework::product(framework::slice_ddim(xdims, 0, xdims.size() - 1)); + const size_t col = xdims[xdims.size() - 1]; + + memset(x_grad_data, 0, row * col * sizeof(T)); + + for (size_t i = 0; i < row; ++i) { + for (size_t j = 0; j < k; ++j) { + size_t idx = indices_data[i * k + j]; + x_grad_data[i * col + idx] = out_grad_data[i * k + j]; + } + } + } +}; + } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_top_k_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_top_k_ngraph_op.py index d80c72ee32..a42f781c65 100644 --- a/python/paddle/fluid/tests/unittests/ngraph/test_top_k_ngraph_op.py +++ b/python/paddle/fluid/tests/unittests/ngraph/test_top_k_ngraph_op.py @@ -15,7 +15,7 @@ from __future__ import print_function import unittest, sys sys.path.append("../") -from test_top_k_op import TestTopkOp, TestTopkOp3d, TestTopkOp2, TestTopkOp3, TestTopkOp4 +from test_top_k_op import TestTopkOp if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_top_k_op.py b/python/paddle/fluid/tests/unittests/test_top_k_op.py index 5327c0f5de..52d1fda0ae 100644 --- a/python/paddle/fluid/tests/unittests/test_top_k_op.py +++ b/python/paddle/fluid/tests/unittests/test_top_k_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle.fluid.core as core class TestTopkOp(OpTest): @@ -24,7 +25,7 @@ class TestTopkOp(OpTest): self.variable_k = False self.set_args() self.op_type = "top_k" - self.dtype = np.float32 + self.dtype = np.float64 self.init_dtype() k = self.top_k @@ -49,106 +50,14 @@ class TestTopkOp(OpTest): pass def set_args(self): - self.row = 32 + self.row = 100 self.top_k = 1 def test_check_output(self): self.check_output() - -class TestTopkOpFp16(TestTopkOp): - def init_dtype(self): - self.dtype = np.float16 - - -class TestTopkOp3d(OpTest): - def setUp(self): - self.op_type = "top_k" - k = 1 - input = np.random.random((32, 2, 84)).astype("float32") - input_flat_2d = input.reshape(64, 84) - output = np.ndarray((64, k)) - indices = np.ndarray((64, k)).astype("int64") - - self.inputs = {'X': input} - self.attrs = {'k': k} - - for rowid in range(64): - row = input_flat_2d[rowid] - output[rowid] = np.sort(row)[::-1][:k] - indices[rowid] = row.argsort()[::-1][:k] - - self.outputs = { - 'Out': output.reshape((32, 2, k)), - 'Indices': indices.reshape((32, 2, k)) - } - - def test_check_output(self): - self.check_output() - - -class TestTopkOp1(OpTest): - def setUp(self): - self.op_type = "top_k" - k = 2 - m = 2056 - input = np.random.random(m).astype("float32") - output = np.ndarray(k) - indices = np.ndarray(k).astype("int64") - - self.inputs = {'X': input} - self.attrs = {'k': k} - - row = input - output = -np.sort(-row)[:k] - indices = (-row).argsort()[:k] - - self.outputs = {'Out': output, 'Indices': indices} - - def test_check_output(self): - self.check_output() - - -class TestTopkOp2(OpTest): - def setUp(self): - self.op_type = "top_k" - k = 1 - m = 2056 - input = np.random.random((m, 84)).astype("float32") - output = np.ndarray((m, k)) - indices = np.ndarray((m, k)).astype("int64") - - self.inputs = {'X': input} - self.attrs = {'k': k} - - for rowid in range(m): - row = input[rowid] - output[rowid] = -np.sort(-row)[:k] - indices[rowid] = (-row).argsort()[:k] - - self.outputs = {'Out': output, 'Indices': indices} - - def test_check_output(self): - self.check_output() - - -class TestTopkOp3(TestTopkOp): - def set_args(self): - self.row = 2056 - self.top_k = 3 - - -class TestTopkOp4(TestTopkOp): - def set_args(self): - self.row = 40000 - self.top_k = 1 - - -class TestTopkOp5(TestTopkOp): - def set_args(self): - self.row = 40000 - self.top_k = 3 - self.variable_k = True + def test_check_grad(self): + self.check_grad(set(['X']), 'Out') if __name__ == "__main__": -- GitLab