提交 6075928d 编写于 作者: Z zchen0211

gather op added

上级 4d2adab7
......@@ -17,6 +17,8 @@ limitations under the License. */
#include <cstring>
#include "paddle/framework/ddim.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
......
......@@ -24,13 +24,9 @@ class GatherOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
// PADDLE_ENFORCE(ctx.InputSize() == 2, "");
// PADDLE_ENFORCE(ctx.OutputSize() == 1, "");
// PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0),
// "Inputs of GatherOp must all be set");
int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
PADDLE_ENFORCE(batch_size > 0);
paddle::framework::DDim output_dims(ctx.Input<Tensor>(0)->dims());
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
paddle::framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
output_dims[0] = batch_size;
ctx.Output<Tensor>("Y")->Resize(output_dims);
}
......
......@@ -12,11 +12,12 @@ class TestGatherOp(unittest.TestCase):
def setUp(self):
self.type = "gather"
xnp = numpy.random.random((10, 20)).astype("float32")
self.inputs = {
'X': numpy.random.random((10, 20)).astype("float32"),
'Index': numpy.array([1, 3, 5]).astype("int")
'X': xnp,
'Index': numpy.array([1, 3, 5]).astype("int32")
}
self.outputs = {'Y': self.input['X'][self.input['Index']]}
self.outputs = {'Y': self.inputs['X'][self.inputs['Index']]}
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册