未验证 提交 8f035fb6 编写于 作者: J Jiawei Wang 提交者: GitHub

Add TopK Op Grad CPU&GPU Kernel test=develop (#22628)

* Add TopK Op Grad CPU&GPU Kernel test=develop

* Add TopK Op Grad, modify grad op maker test=develop

* Add TopK Op Grad, modify grad op maker test=develop

* Add TopK Op Grad, modify PADDLE_ENFORCE test=develop

* Add TopK Op Grad, modify PADDLE_THROW test=develop

* Add TopK Op Grad, modify unittest test=develop

* fix ngraph top k op unittest test=develop
上级 3f0ca61a
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/operators/top_k_op.h"
#include <memory>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -42,11 +43,6 @@ class TopkOp : public framework::OperatorWithKernel { ...@@ -42,11 +43,6 @@ class TopkOp : public framework::OperatorWithKernel {
framework::DDim dims = input_dims; framework::DDim dims = input_dims;
dims[dims.size() - 1] = k; dims[dims.size() - 1] = k;
// If has K as tensor, set k=-1 as not know real size at this time.
if (ctx->HasInput("K")) {
dims[dims.size() - 1] = -1;
}
ctx->SetOutputDim("Out", dims); ctx->SetOutputDim("Out", dims);
ctx->SetOutputDim("Indices", dims); ctx->SetOutputDim("Indices", dims);
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
...@@ -89,16 +85,67 @@ For matrices, this operator computes the top k entries in each row. )DOC"); ...@@ -89,16 +85,67 @@ For matrices, this operator computes the top k entries in each row. )DOC");
} }
}; };
class TopkOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should be not null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Indices"), true,
platform::errors::InvalidArgument("Input(Indices) should be not null"));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Grad Input(Out) should be not null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument("Grad Output(X) should be not null"));
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
template <typename T>
class TopkGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
op->SetType("top_k_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("X", this->Input("X"));
op->SetInput("Indices", this->Output("Indices"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
return op;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR( REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker,
top_k, ops::TopkOp, ops::TopkOpMaker, ops::TopkGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, ops::TopkGradOpMaker<paddle::imperative::OpBase>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(top_k_grad, ops::TopkOpGrad);
REGISTER_OP_CPU_KERNEL(top_k, REGISTER_OP_CPU_KERNEL(top_k,
ops::TopkKernel<paddle::platform::CPUPlace, float>, ops::TopkKernel<paddle::platform::CPUPlace, float>,
ops::TopkKernel<paddle::platform::CPUPlace, double>, ops::TopkKernel<paddle::platform::CPUPlace, double>);
ops::TopkKernel<paddle::platform::CPUPlace, int>,
ops::TopkKernel<paddle::platform::CPUPlace, int64_t>); REGISTER_OP_CPU_KERNEL(top_k_grad,
ops::TopkGradKernel<paddle::platform::CPUPlace, float>,
ops::TopkGradKernel<paddle::platform::CPUPlace, double>);
...@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <cstdio>
#include "cub/cub.cuh" #include "cub/cub.cuh"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
// set cub base traits in order to handle float16 // set cub base traits in order to handle float16
namespace cub { namespace cub {
template <> template <>
...@@ -300,6 +300,20 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, ...@@ -300,6 +300,20 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
} }
} }
template <typename T, int MaxLength, int BlockSize>
__global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
size_t rows, size_t cols, size_t k) {
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
x_grad[i * cols + j] = 0;
}
for (size_t j = 0; j < k; ++j) {
size_t idx = indices[i * k + j];
x_grad[i * cols + idx] = out_grad[i * k + j];
}
}
}
inline static int GetDesiredBlockDim(int dim) { inline static int GetDesiredBlockDim(int dim) {
if (dim > 128) { if (dim > 128) {
return 256; return 256;
...@@ -478,7 +492,7 @@ bool SortTopk(const platform::CUDADeviceContext& ctx, ...@@ -478,7 +492,7 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
template <typename T> template <typename DeviceContext, typename T>
class TopkOpCUDAKernel : public framework::OpKernel<T> { class TopkOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
...@@ -540,6 +554,42 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { ...@@ -540,6 +554,42 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T>
class TopkOpGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));
auto* x = context.Input<Tensor>("X");
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* indices = context.Input<Tensor>("Indices");
auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
const T* out_grad_data = out_grad->data<T>();
const int64_t* indices_data = indices->data<int64_t>();
size_t k = indices->dims()[indices->dims().size() - 1];
framework::DDim xdims = x->dims();
const size_t row =
framework::product(framework::slice_ddim(xdims, 0, xdims.size() - 1));
const size_t col = xdims[xdims.size() - 1];
const auto& dev_ctx = context.cuda_device_context();
const int kMaxHeight = 2048;
int gridx = row < kMaxHeight ? row : kMaxHeight;
switch (GetDesiredBlockDim(col)) {
FIXED_BLOCK_DIM(
AssignGrad<T, 5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
x_grad_data, indices_data, out_grad_data, row, col, k));
default:
PADDLE_THROW(
platform::errors::Unavailable("Error occurs when Assign Grad."));
}
}
};
#undef FIXED_BLOCK_DIM_BASE #undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM #undef FIXED_BLOCK_DIM
...@@ -547,8 +597,27 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { ...@@ -547,8 +597,27 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
top_k, paddle::operators::TopkOpCUDAKernel<float>, top_k,
paddle::operators::TopkOpCUDAKernel<double>, paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::operators::TopkOpCUDAKernel<int>, float>,
paddle::operators::TopkOpCUDAKernel<int64_t>, paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::operators::TopkOpCUDAKernel<paddle::platform::float16>); double>,
paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
int>,
paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
top_k_grad,
paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
float>,
paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
int>,
paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>,
paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
...@@ -94,5 +94,35 @@ class TopkKernel : public framework::OpKernel<T> { ...@@ -94,5 +94,35 @@ class TopkKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T>
class TopkGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* indices = context.Input<Tensor>("Indices");
auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
const T* out_grad_data = out_grad->data<T>();
const int64_t* indices_data = indices->data<int64_t>();
size_t k = indices->dims()[indices->dims().size() - 1];
framework::DDim xdims = x->dims();
const size_t row =
framework::product(framework::slice_ddim(xdims, 0, xdims.size() - 1));
const size_t col = xdims[xdims.size() - 1];
memset(x_grad_data, 0, row * col * sizeof(T));
for (size_t i = 0; i < row; ++i) {
for (size_t j = 0; j < k; ++j) {
size_t idx = indices_data[i * k + j];
x_grad_data[i * col + idx] = out_grad_data[i * k + j];
}
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -15,7 +15,7 @@ from __future__ import print_function ...@@ -15,7 +15,7 @@ from __future__ import print_function
import unittest, sys import unittest, sys
sys.path.append("../") sys.path.append("../")
from test_top_k_op import TestTopkOp, TestTopkOp3d, TestTopkOp2, TestTopkOp3, TestTopkOp4 from test_top_k_op import TestTopkOp
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core
class TestTopkOp(OpTest): class TestTopkOp(OpTest):
...@@ -24,7 +25,7 @@ class TestTopkOp(OpTest): ...@@ -24,7 +25,7 @@ class TestTopkOp(OpTest):
self.variable_k = False self.variable_k = False
self.set_args() self.set_args()
self.op_type = "top_k" self.op_type = "top_k"
self.dtype = np.float32 self.dtype = np.float64
self.init_dtype() self.init_dtype()
k = self.top_k k = self.top_k
...@@ -49,106 +50,14 @@ class TestTopkOp(OpTest): ...@@ -49,106 +50,14 @@ class TestTopkOp(OpTest):
pass pass
def set_args(self): def set_args(self):
self.row = 32 self.row = 100
self.top_k = 1 self.top_k = 1
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self):
class TestTopkOpFp16(TestTopkOp): self.check_grad(set(['X']), 'Out')
def init_dtype(self):
self.dtype = np.float16
class TestTopkOp3d(OpTest):
def setUp(self):
self.op_type = "top_k"
k = 1
input = np.random.random((32, 2, 84)).astype("float32")
input_flat_2d = input.reshape(64, 84)
output = np.ndarray((64, k))
indices = np.ndarray((64, k)).astype("int64")
self.inputs = {'X': input}
self.attrs = {'k': k}
for rowid in range(64):
row = input_flat_2d[rowid]
output[rowid] = np.sort(row)[::-1][:k]
indices[rowid] = row.argsort()[::-1][:k]
self.outputs = {
'Out': output.reshape((32, 2, k)),
'Indices': indices.reshape((32, 2, k))
}
def test_check_output(self):
self.check_output()
class TestTopkOp1(OpTest):
def setUp(self):
self.op_type = "top_k"
k = 2
m = 2056
input = np.random.random(m).astype("float32")
output = np.ndarray(k)
indices = np.ndarray(k).astype("int64")
self.inputs = {'X': input}
self.attrs = {'k': k}
row = input
output = -np.sort(-row)[:k]
indices = (-row).argsort()[:k]
self.outputs = {'Out': output, 'Indices': indices}
def test_check_output(self):
self.check_output()
class TestTopkOp2(OpTest):
def setUp(self):
self.op_type = "top_k"
k = 1
m = 2056
input = np.random.random((m, 84)).astype("float32")
output = np.ndarray((m, k))
indices = np.ndarray((m, k)).astype("int64")
self.inputs = {'X': input}
self.attrs = {'k': k}
for rowid in range(m):
row = input[rowid]
output[rowid] = -np.sort(-row)[:k]
indices[rowid] = (-row).argsort()[:k]
self.outputs = {'Out': output, 'Indices': indices}
def test_check_output(self):
self.check_output()
class TestTopkOp3(TestTopkOp):
def set_args(self):
self.row = 2056
self.top_k = 3
class TestTopkOp4(TestTopkOp):
def set_args(self):
self.row = 40000
self.top_k = 1
class TestTopkOp5(TestTopkOp):
def set_args(self):
self.row = 40000
self.top_k = 3
self.variable_k = True
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册