未验证 提交 c3a87e3d 编写于 作者: D Double_V 提交者: GitHub

support slice double grad, test=develop (#22166) (#22836)

* support slice double grad, test=develop
* merge two doublegradopmaker to one doublegradopmaker,test=develop
* change the shape of slice_OP's unittest, test=develop
上级 9ba52165
...@@ -291,6 +291,34 @@ class SliceOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -291,6 +291,34 @@ class SliceOpGradMaker : public framework::SingleGradOpMaker<T> {
} }
}; };
template <typename T>
class SliceDoubleOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto *bind = new T();
if (this->HasInput("StartsTensor")) {
bind->SetInput("StartsTensor", this->Input("StartsTensor"));
}
if (this->HasInput("EndsTensor")) {
bind->SetInput("EndsTensor", this->Input("EndsTensor"));
}
if (this->HasInput("StartsTensorList")) {
bind->SetInput("StartsTensorList", this->Input("StartsTensorList"));
}
if (this->HasInput("EndsTensorList")) {
bind->SetInput("EndsTensorList", this->Input("EndsTensorList"));
}
bind->SetInput("Input", this->OutputGrad(framework::GradVarName("Input")));
bind->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
bind->SetAttrMap(this->Attrs());
bind->SetType("slice");
return std::unique_ptr<T>(bind);
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SliceOpGradNoNeedBufferVarsInference, DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(SliceOpGradNoNeedBufferVarsInference,
"Input"); "Input");
...@@ -302,6 +330,8 @@ REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker, ...@@ -302,6 +330,8 @@ REGISTER_OPERATOR(slice, ops::SliceOp, ops::SliceOpMaker,
ops::SliceOpGradMaker<paddle::framework::OpDesc>, ops::SliceOpGradMaker<paddle::framework::OpDesc>,
ops::SliceOpGradMaker<paddle::imperative::OpBase>); ops::SliceOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad, REGISTER_OPERATOR(slice_grad, ops::SliceOpGrad,
ops::SliceDoubleOpGradMaker<paddle::framework::OpDesc>,
ops::SliceDoubleOpGradMaker<paddle::imperative::OpBase>,
ops::SliceOpGradNoNeedBufferVarsInference); ops::SliceOpGradNoNeedBufferVarsInference);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -43,6 +43,41 @@ class TestMulGradCheck(unittest.TestCase): ...@@ -43,6 +43,41 @@ class TestMulGradCheck(unittest.TestCase):
self.func(p) self.func(p)
class TestSliceOpDoubleGradCheck(unittest.TestCase):
def func(self, place):
self.config()
out = fluid.layers.slice(
self.inputs, axes=self.axes, starts=self.starts, ends=self.ends)
gradient_checker.double_grad_check(
[self.inputs], out, x_init=self.x_arr, place=place)
def config(self):
self.starts = [1, 0, -1]
self.ends = [3, 3, 6]
self.axes = [0, 1, 2]
self.x_arr = np.random.random([3, 4, 5, 2]).astype("float64")
self.inputs = layers.create_parameter(
dtype="float64", shape=[3, 4, 5, 2], name='x')
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
self.func(place)
class TestSliceOpDoubleGradCheckCase3(TestSliceOpDoubleGradCheck):
def config(self):
self.starts = [1, -1, 1]
self.ends = [3, 3, 3]
self.axes = [0, 1, 2]
self.x_arr = np.random.random([3, 3, 3]).astype("float64")
self.inputs = layers.create_parameter(
dtype="float64", shape=[3, 3, 3], name='x3')
class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase): class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase):
@prog_scope() @prog_scope()
def func(self, place): def func(self, place):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册