未验证 提交 a414e947 编写于 作者: K Kaipeng Deng 提交者: GitHub

fit for new GradMaker (#4425)

上级 afc0f055
...@@ -94,15 +94,13 @@ public: ...@@ -94,15 +94,13 @@ public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
auto* op = new T();
op->SetType("gather_point_grad"); op->SetType("gather_point_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Index", this->Input("Index")); op->SetInput("Index", this->Input("Index"));
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output")); op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
} }
}; };
......
...@@ -102,15 +102,13 @@ class GroupPointsGradDescMaker : public framework::SingleGradOpMaker<T> { ...@@ -102,15 +102,13 @@ class GroupPointsGradDescMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
auto* op = new T();
op->SetType("group_points_grad"); op->SetType("group_points_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Idx", this->Input("Idx")); op->SetInput("Idx", this->Input("Idx"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
} }
}; };
......
...@@ -117,8 +117,7 @@ public: ...@@ -117,8 +117,7 @@ public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker; using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected: protected:
std::unique_ptr<T> Apply() const override { void Apply(GradOpPtr<T> op) const override {
auto* op = new T();
op->SetType("three_interp_grad"); op->SetType("three_interp_grad");
op->SetInput("X", this->Input("X")); op->SetInput("X", this->Input("X"));
op->SetInput("Weight", this->Input("Weight")); op->SetInput("Weight", this->Input("Weight"));
...@@ -126,7 +125,6 @@ protected: ...@@ -126,7 +125,6 @@ protected:
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs()); op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册