提交 f6bffd4e 编写于 作者: Z zchen0211

gather_op modified

上级 2a42a73d
...@@ -51,8 +51,10 @@ Y = X[Index] ...@@ -51,8 +51,10 @@ Y = X[Index]
class GatherGradOp : public framework::OperatorWithKernel { class GatherGradOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("X" + framework::kGradVarSuffix) auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
->Resize(ctx.Input<Tensor>("X")->dims()); auto X = ctx.Input<Tensor>("X");
X_grad->Resize(X->dims());
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册