Commit 53e71b44 authored by zchen0211

gather op bp passed

Parent f3df1054
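This change wires up the backward pass (the "bp" in the commit message) for the gather operator: the forward pass computes Out = X[Index], selecting whole slices along the first axis, and the backward pass scatters the output gradient back into the selected rows of X (the kernel below does this with ScatterUpdate). A minimal NumPy sketch of those semantics, for orientation only — the helper names (gather_forward, gather_backward) are illustrative, not Paddle APIs:

```python
import numpy as np

def gather_forward(x, index):
    # Out = X[Index]: pick whole rows (slices along axis 0) of x.
    return x[index]

def gather_backward(dout, index, x_shape):
    # dX has the shape of X; rows that were never selected get zero gradient.
    dx = np.zeros(x_shape, dtype=dout.dtype)
    # Accumulate each output-row gradient into the source row it came from.
    # With unique indices (as in the test below: [1, 3, 5]) this is the same
    # as a plain scatter of the slices, which is what ScatterUpdate performs.
    np.add.at(dx, index, dout)
    return dx

x = np.random.random((10, 20)).astype("float32")
index = np.array([1, 3, 5], dtype="int32")
out = gather_forward(x, index)                            # shape (3, 20)
dx = gather_backward(np.ones_like(out), index, x.shape)   # rows 1, 3, 5 become ones
```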
@@ -44,8 +44,6 @@ endfunction()
add_subdirectory(math)
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
op_library(gather_op SRCS gather_op.cc gather_op.cu)
-# DEPS op_registry)
-# cc_test(gather_op_test SRCS gather_op_test.cc DEPS gather_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
......
@@ -27,13 +27,13 @@ namespace operators {
// Implementation of CPU copy
template <typename T>
-void CPUGather(const T* params, const int* indices, const int slice_size,
+void CPUGather(const T* src, const int* indices, const int slice_size,
               const int index_size, T* output) {
  const size_t slice_bytes = slice_size * sizeof(T);
  for (int i = 0; i < index_size; ++i) {
    int index_ = indices[i];
-    memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes);
+    memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes);
  }
}
@@ -57,7 +57,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
  int index_size = index->dims()[0];
  auto src_dims = src->dims();
-  paddle::framework::DDim output_dims(src_dims);
+  framework::DDim output_dims(src_dims);
  output_dims[0] = index_size;
  // slice size
......
@@ -26,9 +26,9 @@ class GatherOp : public framework::OperatorWithKernel {
  void InferShape(const framework::InferShapeContext &ctx) const override {
    int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
    PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
-    paddle::framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
+    framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
    output_dims[0] = batch_size;
-    ctx.Output<Tensor>("Y")->Resize(output_dims);
+    ctx.Output<Tensor>("Out")->Resize(output_dims);
  }
};
@@ -51,11 +51,11 @@ class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The source input of gather op");
    AddInput("Index", "The index input of gather op");
-    AddOutput("Y", "The output of add op");
+    AddOutput("Out", "The output of add op");
    AddComment(R"DOC(
Gather Operator by selecting from the first axis,
-Y = X[Index]
+Out = X[Index]
)DOC");
  }
};
......
@@ -26,10 +26,10 @@ using Tensor = framework::Tensor;
template <typename Place, typename T>
class GatherOpKernel : public framework::OpKernel {
 public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto X = ctx.Input<Tensor>("X");
-    auto Index = ctx.Input<Tensor>("Index");
-    auto Y = ctx.Output<Tensor>("Y");
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *X = ctx.Input<Tensor>("X");
+    auto *Index = ctx.Input<Tensor>("Index");
+    auto *Y = ctx.Output<Tensor>("Out");
    Y->mutable_data<T>(ctx.GetPlace());
    Gather<T>(ctx.GetPlace(), X, Index, Y);
@@ -39,12 +39,13 @@ class GatherOpKernel : public framework::OpKernel {
template <typename Place, typename T>
class GatherGradientOpKernel : public framework::OpKernel {
 public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto Index = ctx.Input<Tensor>("Index");
-    auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *Index = ctx.Input<Tensor>("Index");
+    auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    ScatterUpdate<T>(ctx.GetPlace(), dY, Index, dX);
+    dX->mutable_data<T>(ctx.GetPlace());
+    ScatterUpdate<T>(ctx.GetPlace(), dO, Index, dX);
  }
};
......
import unittest
from op_test_util import OpTestMeta
+from gradient_checker import GradientChecker, create_op
import numpy
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
from op_test_util import OpTestMeta
class TestGatherOp(unittest.TestCase):
    __metaclass__ = OpTestMeta
@@ -17,7 +16,18 @@ class TestGatherOp(unittest.TestCase):
            'X': xnp,
            'Index': numpy.array([1, 3, 5]).astype("int32")
        }
-        self.outputs = {'Y': self.inputs['X'][self.inputs['Index']]}
+        self.outputs = {'Out': self.inputs['X'][self.inputs['Index']]}
+class TestGatherGradOp(GradientChecker):
+    def test_gather_grad(self):
+        print 'creating op'
+        op = create_op("gather")
+        print 'creating op done'
+        xnp = numpy.random.random((10, 20)).astype("float32")
+        inputs = {'X': xnp, 'Index': numpy.array([1, 3, 5]).astype("int32")}
+        print 'correct before check gradient'
+        self.check_grad(op, inputs, set("X"), "Out")
if __name__ == "__main__":
......
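The new TestGatherGradOp delegates the actual verification to the framework's GradientChecker; conceptually, check_grad compares the analytic gradient (the ScatterUpdate of the output gradient) against finite differences. Below is a self-contained sketch of that comparison under the NumPy semantics shown earlier — it stands in for, and is not, the real check_grad machinery:

```python
import numpy as np

def loss(x, index, g):
    # Scalar probe: L = sum(gather(x, index) * g), so dL/dX = scatter_add(g, index).
    return float(np.sum(x[index] * g))

rng = np.random.RandomState(0)
x = rng.rand(10, 20)
index = np.array([1, 3, 5], dtype="int32")
g = rng.rand(len(index), 20)

# Analytic gradient: scatter the probe weights back into the selected rows.
analytic = np.zeros_like(x)
np.add.at(analytic, index, g)

# Numeric gradient by central differences, one element at a time.
numeric = np.zeros_like(x)
eps = 1e-6
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        xp = x.copy(); xp[i, j] += eps
        xm = x.copy(); xm[i, j] -= eps
        numeric[i, j] = (loss(xp, index, g) - loss(xm, index, g)) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-4)
```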