提交 6075928d 编写于 作者: Z zchen0211

gather op added

上级 4d2adab7
...@@ -17,6 +17,8 @@ limitations under the License. */ ...@@ -17,6 +17,8 @@ limitations under the License. */
#include <cstring> #include <cstring>
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
......
...@@ -24,13 +24,9 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -24,13 +24,9 @@ class GatherOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
// PADDLE_ENFORCE(ctx.InputSize() == 2, "");
// PADDLE_ENFORCE(ctx.OutputSize() == 1, "");
// PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(0),
// "Inputs of GatherOp must all be set");
int batch_size = ctx.Input<Tensor>("Index")->dims()[0]; int batch_size = ctx.Input<Tensor>("Index")->dims()[0];
PADDLE_ENFORCE(batch_size > 0); PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
paddle::framework::DDim output_dims(ctx.Input<Tensor>(0)->dims()); paddle::framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
output_dims[0] = batch_size; output_dims[0] = batch_size;
ctx.Output<Tensor>("Y")->Resize(output_dims); ctx.Output<Tensor>("Y")->Resize(output_dims);
} }
......
...@@ -12,11 +12,12 @@ class TestGatherOp(unittest.TestCase): ...@@ -12,11 +12,12 @@ class TestGatherOp(unittest.TestCase):
def setUp(self): def setUp(self):
self.type = "gather" self.type = "gather"
xnp = numpy.random.random((10, 20)).astype("float32")
self.inputs = { self.inputs = {
'X': numpy.random.random((10, 20)).astype("float32"), 'X': xnp,
'Index': numpy.array([1, 3, 5]).astype("int") 'Index': numpy.array([1, 3, 5]).astype("int32")
} }
self.outputs = {'Y': self.input['X'][self.input['Index']]} self.outputs = {'Y': self.inputs['X'][self.inputs['Index']]}
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册