提交 a0aa9073 编写于 作者: Z Zhuoyuan 提交者: GitHub

Merge pull request #3540 from zchen0211/develop

Gather_op with python op passed
......@@ -45,6 +45,7 @@ cc_library(paddle_pybind SHARED
SRCS pybind.cc
DEPS pybind python backward
sgd_op
gather_op
add_op
mul_op
rowwise_add_op
......
......@@ -42,6 +42,7 @@ USE_OP(fill_zeros_like);
USE_OP_ITSELF(recurrent_op);
USE_OP(gaussian_random);
USE_OP(uniform_random);
USE_CPU_ONLY_OP(gather);
namespace paddle {
namespace framework {
......
......@@ -43,6 +43,7 @@ endfunction()
add_subdirectory(math)
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
op_library(gather_op SRCS gather_op.cc gather_op.cu)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <cstring>
#include "paddle/framework/ddim.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
......@@ -25,13 +26,13 @@ namespace operators {
// Implementation of CPU copy
template <typename T>
void CPUGather(const T* params, const int* indices, const int slice_size,
void CPUGather(const T* src, const int* indices, const int slice_size,
const int index_size, T* output) {
const size_t slice_bytes = slice_size * sizeof(T);
for (int i = 0; i < index_size; ++i) {
int index_ = indices[i];
memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes);
memcpy(output + i * slice_size, src + index_ * slice_size, slice_bytes);
}
}
......@@ -55,7 +56,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
int index_size = index->dims()[0];
auto src_dims = src->dims();
paddle::framework::DDim output_dims(src_dims);
framework::DDim output_dims(src_dims);
output_dims[0] = index_size;
// slice size
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/gather_op.h"
#include "paddle/framework/ddim.h"
namespace paddle {
namespace operators {
class GatherOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
output_dims[0] = batch_size;
ctx.Output<Tensor>("Out")->Resize(output_dims);
}
};
class GatherGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
X_grad->Resize(X->dims());
}
};
class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GatherOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The source input of gather op");
AddInput("Index", "The index input of gather op");
AddOutput("Out", "The output of add op");
AddComment(R"DOC(
Gather Operator by selecting from the first axis,
Out = X[Index]
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(gather, ops::GatherOp, ops::GatherOpMaker, gather_grad,
ops::GatherGradOp);
REGISTER_OP_CPU_KERNEL(gather,
ops::GatherOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
gather_grad,
ops::GatherGradientOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/gather_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gather,
ops::GatherOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "gather.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "scatter.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class GatherOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X");
auto *Index = ctx.Input<Tensor>("Index");
auto *Y = ctx.Output<Tensor>("Out");
Y->mutable_data<T>(ctx.GetPlace());
Gather<T>(ctx.GetPlace(), X, Index, Y);
}
};
template <typename Place, typename T>
class GatherGradientOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *Index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
dX->mutable_data<T>(ctx.GetPlace());
ScatterUpdate<T>(ctx.GetPlace(), dO, Index, dX);
}
};
} // namespace operators
} // namespace paddle
......@@ -13,6 +13,7 @@ py_test(test_add_two_op SRCS test_add_two_op.py)
py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
py_test(test_softmax_op SRCS test_softmax_op.py)
py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py)
py_test(test_gather_op SRCS test_gather_op.py)
py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
py_test(gradient_checker SRCS gradient_checker.py)
......
import unittest
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
import numpy
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
class TestGatherOp(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "gather"
xnp = numpy.random.random((10, 20)).astype("float32")
self.inputs = {
'X': xnp,
'Index': numpy.array([1, 3, 5]).astype("int32")
}
self.outputs = {'Out': self.inputs['X'][self.inputs['Index']]}
class TestGatherGradOp(GradientChecker):
def test_gather_grad(self):
print 'creating op'
op = create_op("gather")
print 'creating op done'
xnp = numpy.random.random((10, 20)).astype("float32")
inputs = {'X': xnp, 'Index': numpy.array([1, 3, 5]).astype("int32")}
print 'correct before check gradient'
self.check_grad(op, inputs, set("X"), "Out")
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册