提交 4d2adab7 编写于 作者: Z zchen0211

gather op added with python unittest

上级 323d4233
...@@ -47,6 +47,7 @@ cc_library(paddle_pybind SHARED ...@@ -47,6 +47,7 @@ cc_library(paddle_pybind SHARED
SRCS pybind.cc SRCS pybind.cc
DEPS pybind python backward DEPS pybind python backward
sgd_op sgd_op
gather_op
add_op add_op
mul_op mul_op
rowwise_add_op rowwise_add_op
......
...@@ -42,6 +42,7 @@ USE_OP(fill_zeros_like); ...@@ -42,6 +42,7 @@ USE_OP(fill_zeros_like);
USE_OP_ITSELF(recurrent_op); USE_OP_ITSELF(recurrent_op);
USE_OP(gaussian_random); USE_OP(gaussian_random);
USE_OP(uniform_random); USE_OP(uniform_random);
USE_CPU_ONLY_OP(gather);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -43,7 +43,8 @@ endfunction() ...@@ -43,7 +43,8 @@ endfunction()
add_subdirectory(math) add_subdirectory(math)
cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_library(gather_op SRCS gather_op.cc DEPS op_registry) op_library(gather_op SRCS gather_op.cc gather_op.cu)
# DEPS op_registry)
# cc_test(gather_op_test SRCS gather_op_test.cc DEPS gather_op) # cc_test(gather_op_test SRCS gather_op_test.cc DEPS gather_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
......
...@@ -19,17 +19,33 @@ namespace paddle { ...@@ -19,17 +19,33 @@ namespace paddle {
namespace operators { namespace operators {
class GatherOp : public framework::OperatorWithKernel { class GatherOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, ""); // PADDLE_ENFORCE(ctx.InputSize() == 2, "");
PADDLE_ENFORCE(ctx.OutputSize() == 1, ""); // PADDLE_ENFORCE(ctx.OutputSize() == 1, "");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0), // PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0),
"Inputs of GatherOp must all be set"); // "Inputs of GatherOp must all be set");
int batch_size = ctx.Input<Tensor>(1)->dims()[0]; int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
PADDLE_ENFORCE(batch_size > 0); PADDLE_ENFORCE(batch_size > 0);
paddle::framework::DDim output_dims(ctx.Input<Tensor>(0)->dims()); paddle::framework::DDim output_dims(ctx.Input<Tensor>(0)->dims());
output_dims[0] = batch_size; output_dims[0] = batch_size;
ctx.Output<Tensor>(0)->Resize(output_dims); ctx.Output<Tensor>("Y")->Resize(output_dims);
}
};
class GatherGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
X_grad->Resize(X->dims());
} }
}; };
...@@ -47,25 +63,14 @@ Y = X[Index] ...@@ -47,25 +63,14 @@ Y = X[Index]
)DOC"); )DOC");
} }
}; };
class GatherGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
X_grad->Resize(X->dims());
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker); REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
ops::GatherGradOp);
REGISTER_OP_CPU_KERNEL(gather, REGISTER_OP_CPU_KERNEL(gather,
ops::GatherOpKernel<paddle::platform::CPUPlace, float>); ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(gather, gather_grad, ops::GatherGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
gather_grad, gather_grad,
ops::GatherGradientOpKernel<paddle::platform::CPUPlace, float>); ops::GatherGradientOpKernel<paddle::platform::CPUPlace, float>);
...@@ -13,6 +13,7 @@ py_test(test_add_two_op SRCS test_add_two_op.py) ...@@ -13,6 +13,7 @@ py_test(test_add_two_op SRCS test_add_two_op.py)
py_test(test_sigmoid_op SRCS test_sigmoid_op.py) py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
py_test(test_softmax_op SRCS test_softmax_op.py) py_test(test_softmax_op SRCS test_softmax_op.py)
py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py) py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py)
py_test(test_gather_op SRCS test_gather_op.py)
py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py) py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
py_test(gradient_checker SRCS gradient_checker.py) py_test(gradient_checker SRCS gradient_checker.py)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册