未验证 提交 a414e947 编写于 作者: K Kaipeng Deng 提交者: GitHub

fit for new GradMaker (#4425)

上级 afc0f055
......@@ -94,15 +94,13 @@ public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("gather_point_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Index", this->Input("Index"));
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
......
......@@ -102,15 +102,13 @@ class GroupPointsGradDescMaker : public framework::SingleGradOpMaker<T> {
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("group_points_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Idx", this->Input("Idx"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
......
......@@ -117,8 +117,7 @@ public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
void Apply(GradOpPtr<T> op) const override {
op->SetType("three_interp_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Weight", this->Input("Weight"));
......@@ -126,7 +125,6 @@ protected:
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册