提交 44703329 编写于 作者: D dangqingqing

Make some operator correctly handle gradients for multi inputs.

上级 35420274
......@@ -75,8 +75,8 @@ class MulOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE(y_dims[1] == out_dims[1],
"Out@GRAD M X N must equal to Y dims 1, N ");
x_grad->Resize(x_dims);
y_grad->Resize(y_dims);
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
}
};
......
......@@ -31,13 +31,13 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Y = context.Input<Tensor>("Y");
auto* Z = context.Output<Tensor>("Out");
Z->mutable_data<T>(context.GetPlace());
auto* x = context.Input<Tensor>("X");
auto* y = context.Input<Tensor>("Y");
auto* z = context.Output<Tensor>("Out");
z->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context);
math::matmul<Place, T>(*x, false, *y, false, 1, z, 0, device_context);
}
};
......@@ -45,20 +45,24 @@ template <typename Place, typename T>
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* X = ctx.Input<Tensor>("X");
auto* Y = ctx.Input<Tensor>("Y");
auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
dX->mutable_data<T>(ctx.GetPlace());
dY->mutable_data<T>(ctx.GetPlace());
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_);
// dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N
math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
// dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
if (dx) {
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
dx->mutable_data<T>(ctx.GetPlace());
math::matmul<Place, T>(*dout, false, *y, true, 1, dx, 0, device_context);
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
math::matmul<Place, T>(*x, true, *dout, false, 1, dy, 0, device_context);
}
}
};
......
......@@ -64,8 +64,10 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
auto dims0 = ctx.Input<Tensor>("X")->dims();
auto dims1 = ctx.Input<Tensor>("b")->dims();
PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *db = ctx.Output<Tensor>(framework::GradVarName("b"));
if (dx) dx->Resize(dims0);
if (db) db->Resize(dims1);
}
};
......
......@@ -51,20 +51,24 @@ template <typename Place, typename T>
class RowwiseAddGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
auto* db = context.Output<Tensor>(framework::GradVarName("b"));
dX->mutable_data<T>(context.GetPlace());
db->mutable_data<T>(context.GetPlace());
auto OutGrad = EigenMatrix<T>::From(*dOut);
auto out_grad = EigenMatrix<T>::From(*dout);
auto place = context.GetEigenDevice<Place>();
EigenMatrix<T>::From(*dX).device(place) = OutGrad;
if (dx) {
dx->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::From(*dx).device(place) = out_grad;
}
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
// colwise add
Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims);
if (db) {
db->mutable_data<T>(context.GetPlace());
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
// colwise add
Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
EigenVector<T>::Flatten(*db).device(place) = out_grad.sum(dims);
}
}
};
} // namespace operators
......
......@@ -50,8 +50,8 @@ class ScatterGradOp : public framework::OperatorWithKernel {
auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
auto *Ref = ctx.Input<Tensor>("Ref");
dRef->Resize(Ref->dims());
dUpdates->Resize(Updates->dims());
if (dRef) dRef->Resize(Ref->dims());
if (dUpdates) dUpdates->Resize(Updates->dims());
}
};
......
......@@ -49,10 +49,12 @@ class ScatterGradientOpKernel : public framework::OpKernel {
auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
// In place gradient: dRef = dO
dRef->ShareDataWith<T>(*dOut);
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates += dO[Index]
Gather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
if (dRef) dRef->ShareDataWith<T>(*dOut);
if (dUpdates) {
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates += dO[Index]
Gather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
}
}
};
......
......@@ -286,6 +286,9 @@ class GradientChecker(unittest.TestCase):
for no_grad in no_grad_set:
if no_grad not in in_names:
raise ValueError("no_grad should be in in_names")
if name in inputs_to_check:
raise ValueError("no_grad should not be in inputs_to_check")
backward_op = core.Operator.backward(forward_op, no_grad_set)
places = [core.CPUPlace()]
......@@ -301,9 +304,25 @@ class GradientChecker(unittest.TestCase):
check_names = [grad_var_name(name) for name in inputs_to_check]
for place in places:
# get analytical gradients according to different device
analytic_grads = self.__get_gradient(forward_op, backward_op,
input_vars, check_names, place)
# analytic_grads = self.__get_gradient(forward_op, backward_op,
# input_vars, check_names, place)
# In fact, the above two lines can be used to replace following
# codes. But most of the gradient operators need to handle the case
# where one of more of the gradient of the input is not needed.
# We change the unit test framework to explicitly test whether
# the operator correctly handles this through follow codes.
# In addtion, if all the inputs have no gradients, the NOP operator
# will be returned by core.Operator.backward(). The following codes
# do not test this case.
analytic_grads = []
for name in inputs_to_check:
no_grads = [name for name in no_grad_set]
no_grads.extend(filter(lambda x: x != name, inputs_to_check))
backward_op = core.Operator.backward(forward_op, set(no_grads))
# get analytical gradients according to different device
analytic_grads.extend(
self.__get_gradient(forward_op, backward_op, input_vars,
[grad_var_name(name)], place))
self.__assert_is_close(numeric_grads, analytic_grads, check_names,
max_relative_error,
"Gradient Check On %s" % str(place))
......@@ -16,13 +16,14 @@ class TestMulOp(unittest.TestCase):
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
class MulGradOpTest(GradientChecker):
class TestMulGradOp(GradientChecker):
def test_mul(self):
op = create_op("mul")
inputs = {
'X': np.random.random((32, 84)).astype("float32"),
'Y': np.random.random((84, 100)).astype("float32")
}
self.compare_grad(op, inputs)
# mul op will enlarge the relative error
self.check_grad(
op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册