提交 5d1ac41b 编写于 作者: L lvmengsi 提交者: Kaipeng Deng

Double backward reduce mean (#17372)

* test=develop, double backward reduce_mean

* add comment. test=develop

* fix format. test=develop

* rename GradGrad -> DoubleGrad. test=develop

* fix op_use_default_grad_op_maker.spec. test=develop
上级 0cae5a36
......@@ -29,7 +29,6 @@ prelu
quantize
rank_loss
reduce_max
reduce_mean
reduce_min
reduce_prod
reduce_sum
......
......@@ -13,8 +13,77 @@
// limitations under the License.
#include "paddle/fluid/operators/reduce_ops/reduce_mean_op.h"
#include <memory>
#include <string>
#include <vector>
REGISTER_REDUCE_OP(reduce_mean);
namespace paddle {
namespace operators {
// NOTE(dengkaipeng): Input(Out) is unnecessary in reduce_mean_grad
// calcualtion, but will incur a reduce_mean_grad op after
// reduce_mean_grad_grad, delete Input(Out) here.
// This change has no effect on reduce_mean_grad calculations.
class ReduceMeanOpGradDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("reduce_mean_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return op;
}
};
class ReduceMeanDoubleGradMaker : public framework::GradOpDescMakerBase {
public:
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
std::vector<std::unique_ptr<framework::OpDesc>> ops;
auto x_grads = InputGrad("X");
auto x_gg = OutputGrad(framework::GradVarName("X")); // input ddx
if (!x_grads.empty()) {
auto* x_grad_op = new framework::OpDesc();
x_grad_op->SetType("scale");
x_grad_op->SetInput("X", x_gg);
x_grad_op->SetOutput("Out", x_grads);
x_grad_op->SetAttr("scale", 0.0f);
ops.emplace_back(x_grad_op);
}
auto out_grads = InputGrad(framework::GradVarName("Out"));
if (!out_grads.empty()) {
auto* out_grad_op = new framework::OpDesc();
out_grad_op->SetType("reduce_mean");
out_grad_op->SetInput("X", x_gg);
out_grad_op->SetAttrMap(Attrs());
out_grad_op->SetOutput("Out", out_grads);
ops.emplace_back(out_grad_op);
}
return ops;
}
};
} // namespace operators
} // namespace paddle
class __reduce_meanMaker__ : public ops::ReduceOpMaker {
protected:
virtual std::string GetName() const { return "reduce_mean"; }
virtual std::string GetOpType() const { return "Reduce reduce_mean"; }
};
REGISTER_OPERATOR(reduce_mean, ops::ReduceOp, __reduce_meanMaker__,
ops::ReduceMeanOpGradDescMaker);
REGISTER_OPERATOR(reduce_mean_grad, ops::ReduceGradOp,
ops::ReduceMeanDoubleGradMaker);
REGISTER_OP_CPU_KERNEL(reduce_mean,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
float, ops::MeanFunctor>,
......
......@@ -88,6 +88,10 @@ class ReduceGradKernel : public framework::OpKernel<T> {
auto* output = context.Output<Tensor>(framework::GradVarName("X"));
output->mutable_data<T>(context.GetPlace());
// NOTE(dengkaipeng): Out is unnecessary in some reduce kernel and
// not be set as Input in grad Maker, use Out_grad to replace here
if (!input1) input1 = input2;
if (reduce_all) {
auto x = EigenVector<T>::Flatten(*input0);
auto x_reduce = EigenVector<T>::From(*input1);
......
......@@ -166,6 +166,29 @@ class TestElementwiseMulDoubleGradCheck(unittest.TestCase):
self.func(p)
class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [7, 11]
eps = 0.05
dtype = np.float64
x = layers.data('x', shape, False, dtype)
x.persistable = True
y = layers.reduce_mean(x, dim=0)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestElementwiseMulBroadcastDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册