提交 2a42a73d 编写于 作者: Z zchen0211

modify gather_op with test

上级 caaa5f86
......@@ -27,6 +27,9 @@ class GatherOp : public framework::OperatorWithKernel {
"Inputs of GatherOp must all be set");
int batch_size = ctx.Input<Tensor>(1)->dims()[0];
PADDLE_ENFORCE(batch_size > 0);
paddle::framework::DDim output_dims(ctx.Input<Tensor>(0)->dims());
output_dims[0] = batch_size;
ctx.Output<Tensor>(0)->Resize(output_dims);
}
};
......@@ -48,8 +51,8 @@ Y = X[Index]
class GatherGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
// ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
// ->Resize(ctx.Input<Tensor>("X")->dims());
ctx.Output<Tensor>("X" + framework::kGradVarSuffix)
->Resize(ctx.Input<Tensor>("X")->dims());
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册