未验证 提交 443a713c 编写于 作者: F FlyingQianMM 提交者: GitHub

add backward gradient computation for op argsort (#22203)

* add backward gradient computation for op argsort test=developo

* use pre-commit test=develop
上级 46189b16
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/argsort_op.h"
#include <memory>
namespace paddle {
namespace operators {
......@@ -21,7 +22,7 @@ class ArgsortOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ArgsortOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -49,6 +50,24 @@ class ArgsortOp : public framework::OperatorWithKernel {
}
};
class ArgsortGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
class ArgsortOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -83,16 +102,42 @@ Output(Indices) gives the sorted order along the given axis Attr(axis).
}
};
template <typename T>
class ArgsortGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
std::unique_ptr<T> op(new T());
op->SetType("argsort_grad");
op->SetInput("Indices", this->Output("Indices"));
op->SetInput("X", this->Input("X"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ArgsortGradNoNeedBufferVarInference, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(argsort, ops::ArgsortOp, ops::ArgsortOpMaker,
ops::ArgsortGradOpMaker<paddle::framework::OpDesc>,
ops::ArgsortGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(argsort_grad, ops::ArgsortGradOp,
ops::ArgsortGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(argsort,
ops::ArgsortKernel<paddle::platform::CPUPlace, float>,
ops::ArgsortKernel<paddle::platform::CPUPlace, double>,
ops::ArgsortKernel<paddle::platform::CPUPlace, int>,
ops::ArgsortKernel<paddle::platform::CPUPlace, int64_t>);
REGISTER_OP_CPU_KERNEL(
argsort_grad, ops::ArgsortGradientKernel<paddle::platform::CPUPlace, float>,
ops::ArgsortGradientKernel<paddle::platform::CPUPlace, double>,
ops::ArgsortGradientKernel<paddle::platform::CPUPlace, int>,
ops::ArgsortGradientKernel<paddle::platform::CPUPlace, int64_t>);
......@@ -58,6 +58,19 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
}
}
template <typename T, typename IndType>
static __global__ void FillGrad(const T* dO, const IndType* indices, T* dX,
IndType num_rows, IndType num_cols) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
for (IndType j = row_id; j < num_rows; j += gridDim.x) {
for (IndType i = col_id; i < num_cols; i += blockDim.x) {
dX[j * num_cols + indices[j * num_cols + i]] = dO[j * num_cols + i];
}
}
}
// Sort by flag descending, True: descending. False: Ascending.
// Default is false.
template <typename T, typename IndType>
......@@ -160,6 +173,35 @@ void ArgFullSort(const platform::CUDADeviceContext& ctx, const Tensor* input,
temp_storage_bytes, cudaGetErrorString(err));
}
template <typename T, typename IndType>
void ArgFullAssign(const platform::CUDADeviceContext& ctx, const Tensor* dO,
const Tensor* indices, Tensor* dX, const IndType num_rows,
const IndType num_cols) {
auto cu_stream = ctx.stream();
auto ComputeBlockSize = [](IndType col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
};
int block_size = ComputeBlockSize(num_cols);
int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
// actually, int num_rows < max_grid_size
int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
FillGrad<<<grid_size, block_size, 0, cu_stream>>>(
dO->data<T>(), indices->data<IndType>(), dX->data<T>(), num_rows,
num_cols);
}
template <typename T>
class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
public:
......@@ -234,6 +276,81 @@ class ArgsortOpCUDAKernel : public framework::OpKernel<T> {
}
};
template <typename T>
class ArgsortGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* indices = ctx.Input<Tensor>("Indices");
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto& place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;
auto in_dims = indices->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
int64_t numel = indices->numel();
// Special case for full sort, speedup ~190x.
if (axis == -1 || axis + 1 == in_dims.size()) {
const int64_t input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
ArgFullAssign<T, int64_t>(dev_ctx, dO, indices, dX, input_height,
input_width);
} else {
// if not full sort, do transpose first
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.push_back(i);
}
trans.push_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.push_back(i);
}
trans.push_back(axis);
framework::DDim trans_dims(in_dims);
for (int i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
}
Tensor trans_dO;
trans_dO.mutable_data<T>(trans_dims, ctx.GetPlace());
Tensor trans_ind;
trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
const auto& dev_ctx = ctx.cuda_device_context();
// Do transpose
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *dO,
&trans_dO, trans);
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, *indices, &trans_ind, trans);
const int64_t input_height = framework::product(
framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
Tensor tmp_out;
tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
ArgFullAssign<T, int64_t>(dev_ctx, &trans_dO, &trans_ind, &tmp_out,
input_height, input_width);
// transpose back
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, tmp_out, dX,
trans);
return;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -243,3 +360,9 @@ REGISTER_OP_CUDA_KERNEL(
paddle::operators::ArgsortOpCUDAKernel<int>,
paddle::operators::ArgsortOpCUDAKernel<int64_t>,
paddle::operators::ArgsortOpCUDAKernel<paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
argsort_grad, paddle::operators::ArgsortGradOpCUDAKernel<float>,
paddle::operators::ArgsortGradOpCUDAKernel<double>,
paddle::operators::ArgsortGradOpCUDAKernel<int>,
paddle::operators::ArgsortGradOpCUDAKernel<int64_t>,
paddle::operators::ArgsortGradOpCUDAKernel<paddle::platform::float16>);
......@@ -68,6 +68,31 @@ static void FullSort(Type input_height, Type input_width, int input_dim,
}
}
}
template <typename T, typename Type>
static void FullAssign(Type input_height, Type input_width, int input_dim,
const framework::Tensor* input,
const framework::Tensor* indices, T* t_out) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (Type i = 0; i < input_height; ++i) {
if (input_dim == 1) {
auto e_input = EigenVector<T>::Flatten(*input);
auto e_indices = EigenVector<Type>::Flatten(*indices);
for (Type j = 0; j < input_width; ++j) {
t_out[i * input_width + e_indices(j)] = e_input(e_indices(j));
}
} else {
auto e_input = EigenMatrix<T>::Reshape(*input, input_dim - 1);
auto e_indices = EigenMatrix<Type>::Reshape(*indices, input_dim - 1);
for (Type j = 0; j < input_width; ++j) {
t_out[i * input_width + e_indices(i, j)] = e_input(i, e_indices(i, j));
}
}
}
}
template <typename DeviceContext, typename T>
class ArgsortKernel : public framework::OpKernel<T> {
public:
......@@ -142,5 +167,77 @@ class ArgsortKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class ArgsortGradientKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* indices = ctx.Input<Tensor>("Indices");
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");
auto in_dims = indices->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto& place = *ctx.template device_context<platform::CPUDeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;
// Do full assign
if (axis == -1 || axis + 1 == in_dims.size()) {
const int64_t input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
FullAssign<T, int64_t>(input_height, input_width, in_dims.size(), dO,
indices, dX->data<T>());
} else {
// If not full assign do transpose
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.push_back(i);
}
trans.push_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.push_back(i);
}
trans.push_back(axis);
framework::DDim trans_dims(in_dims);
for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
}
Tensor trans_dO;
trans_dO.mutable_data<T>(trans_dims, ctx.GetPlace());
Tensor trans_ind;
trans_ind.mutable_data<int64_t>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
// Do transpose
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, *dO,
&trans_dO, trans);
TransCompute<platform::CPUDeviceContext, int64_t>(
ndims, dev_ctx, *indices, &trans_ind, trans);
const int64_t input_height = framework::product(
framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
Tensor tmp_out;
T* t_out = tmp_out.mutable_data<T>(trans_dims, ctx.GetPlace());
FullAssign<T, int64_t>(input_height, input_width, in_dims.size(),
&trans_dO, &trans_ind, t_out);
// transpose back
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_ctx, tmp_out, dX,
trans);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -48,7 +48,7 @@ class TestArgsortOp(OpTest):
self.axis = -1
def init_datatype(self):
self.dtype = "float32"
self.dtype = "float64"
def init_direction(self):
self.descending = False
......@@ -56,6 +56,9 @@ class TestArgsortOp(OpTest):
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestArgsortOpAxis0(TestArgsortOp):
def init_axis(self):
......@@ -146,5 +149,18 @@ class TestArgsortOpDescendingAxisNeg2(TestArgsortOpAxisNeg2):
self.descending = True
class TestArgsortOpFP32Axis(TestArgsortOp):
def init_datatype(self):
self.dtype = "float32"
class TestArgsortOpFP32DescendingAxis(TestArgsortOp):
def init_datatype(self):
self.dtype = "float32"
def init_direction(self):
self.descending = True
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册