提交 4d2adab7 编写于 作者: Z zchen0211

gather op added with python unittest

上级 323d4233
......@@ -47,6 +47,7 @@ cc_library(paddle_pybind SHARED
SRCS pybind.cc
DEPS pybind python backward
sgd_op
gather_op
add_op
mul_op
rowwise_add_op
......
......@@ -42,6 +42,7 @@ USE_OP(fill_zeros_like);
USE_OP_ITSELF(recurrent_op);
USE_OP(gaussian_random);
USE_OP(uniform_random);
USE_CPU_ONLY_OP(gather);
namespace paddle {
namespace framework {
......
......@@ -43,7 +43,8 @@ endfunction()
add_subdirectory(math)
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
# Build gather_op via op_library so both the CPU (.cc) and GPU (.cu)
# kernels are compiled and registered. The previous
# `cc_library(gather_op SRCS gather_op.cc DEPS op_registry)` form defined
# the same target name, which CMake rejects as a duplicate, and only built
# the CPU side — it is removed here.
op_library(gather_op SRCS gather_op.cc gather_op.cu)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
......
......@@ -19,17 +19,33 @@ namespace paddle {
namespace operators {
class GatherOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Infers the shape of output Y for Y = X[Index]: dim 0 is Index's batch
  // size, the remaining dims are copied from X.
  //
  // NOTE: the diff residue declared `batch_size` (and resized the output)
  // twice — once with positional accessors (Input(0)/Output(0)) and once
  // with named ones. The named form is kept consistently, matching the
  // "X" name used by GatherGradOp below.
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Index"),
                            "Input(Index) of GatherOp must be set");
    int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
    // Gathering zero rows would produce an empty (invalid) output tensor.
    PADDLE_ENFORCE(batch_size > 0, "Batch size of Index must be positive");
    paddle::framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
    output_dims[0] = batch_size;
    ctx.Output<Tensor>("Y")->Resize(output_dims);
  }
};
class GatherGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // The gradient w.r.t. X has exactly the shape of X: every gathered row
  // scatters its gradient back into the corresponding row of X.
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto *x = ctx.Input<Tensor>("X");
    ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(x->dims());
  }
};
......@@ -47,25 +63,14 @@ Y = X[Index]
)DOC");
}
};
// NOTE(review): this looks like the stale pre-refactor copy of
// GatherGradOp left over from the diff — it redefines the class declared
// earlier in this file and, unlike that version, does not inherit the
// base-class constructors (`using framework::OperatorWithKernel::
// OperatorWithKernel;`), which the 5-arg REGISTER_OP form appears to
// require. One of the two definitions should be removed — confirm against
// the applied commit.
class GatherGradOp : public framework::OperatorWithKernel {
 protected:
  // Shape of Grad(X) is simply the shape of X.
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto X = ctx.Input<Tensor>("X");
    X_grad->Resize(X->dims());
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

// Register the forward and gradient ops together with the 5-argument
// REGISTER_OP form. The diff residue also kept the old 3-argument
// REGISTER_OP plus a separate REGISTER_GRADIENT_OP, which would register
// `gather` (and its gradient) twice — those duplicates are dropped.
REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
            ops::GatherGradOp);
REGISTER_OP_CPU_KERNEL(gather,
                       ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    gather_grad,
    ops::GatherGradientOpKernel<paddle::platform::CPUPlace, float>);
......@@ -13,6 +13,7 @@ py_test(test_add_two_op SRCS test_add_two_op.py)
py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
py_test(test_softmax_op SRCS test_softmax_op.py)
py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py)
py_test(test_gather_op SRCS test_gather_op.py)
py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
py_test(gradient_checker SRCS gradient_checker.py)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册