未验证 提交 78cd3dd5 编写于 作者: F Feiyu Chan 提交者: GitHub

fix kron_op: when only one input needs gradient, test=develop (#24269)

fix kron_op: when only one input needs gradient
上级 5dc069d0
......@@ -99,17 +99,15 @@ class KronGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "kron_grad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "kron_grad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Y")), "Output",
framework::GradVarName("Y"), "kron_grad");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
ctx->ShareLoD("Y", /*->*/ y_grad_name);
}
}
protected:
......
......@@ -147,12 +147,15 @@ struct KronGradElemFunctor {
index_b += stride_b_[i] * pos_bi;
}
if (dout_a_) {
size_t index_out_a = index_a * numel_b_ + index_b;
size_t index_out_b = index_b * numel_a_ + index_a;
dout_a_[index_out_a] = dout_[idx] * B_[index_b];
}
if (dout_b_) {
size_t index_out_b = index_b * numel_a_ + index_a;
dout_b_[index_out_b] = dout_[idx] * A_[index_a];
}
}
private:
const T* dout_;
......@@ -222,35 +225,50 @@ struct KronGradOpFunctor {
// dout_x: dout * kron(ones(X), Y) re-aranged in shape (numel_x, numel_y)
// dout_y: dout * kron(X, ones(Y)) re-aranged in shaoe (numel_y, numel_x)
framework::Tensor dout_x;
T* p_dout_x = nullptr;
if (dx) {
dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
p_dout_x = dout_x.data<T>();
}
framework::Tensor dout_y;
T* p_dout_y = nullptr;
if (dy) {
dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
p_dout_y = dout_y.data<T>();
}
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
KronGradElemFunctor<T> func(dout.data<T>(), x.data<T>(), y.data<T>(),
dout_x.data<T>(), dout_y.data<T>(),
p_stride_dout, p_stride_x, p_stride_y,
p_shape_y, numel_x, numel_y, ndims);
p_dout_x, p_dout_y, p_stride_dout, p_stride_x,
p_stride_y, p_shape_y, numel_x, numel_y, ndims);
for_range(func);
// reduce_sum along aixs 1
#if __NVCC__
auto stream = dev_ctx.stream(); // it is a cuda device_context
if (dx) {
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
stream);
}
if (dy) {
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
stream);
}
#else
auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
auto* place = dev_ctx.eigen_device();
Eigen::array<int, 1> reduce_dim = {1};
if (dx) {
auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
}
if (dy) {
auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
}
#endif
}
};
......@@ -307,17 +325,33 @@ class KronGradKernel : public framework::OpKernel<T> {
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
}
int ndims = dout->dims().size();
framework::Tensor xx = UnsqueezeTo(*x, ndims);
framework::Tensor dxx = UnsqueezeTo(*dx, ndims);
framework::Tensor yy = UnsqueezeTo(*y, ndims);
framework::Tensor dyy = UnsqueezeTo(*dy, ndims);
framework::Tensor* pdxx = nullptr;
framework::Tensor* pdyy = nullptr;
framework::Tensor dxx;
framework::Tensor dyy;
if (dx) {
dxx = UnsqueezeTo(*dx, ndims);
pdxx = &dxx;
}
if (dy) {
dyy = UnsqueezeTo(*dy, ndims);
pdyy = &dyy;
}
KronGradOpFunctor<DeviceContext, T> func;
func(dev_ctx, *dout, xx, yy, &dxx, &dyy);
func(dev_ctx, *dout, xx, yy, pdxx, pdyy);
}
};
......
......@@ -42,6 +42,12 @@ class TestKronOp(OpTest):
def test_check_grad(self):
self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ignore_x(self):
self.check_grad(['Y'], 'Out', no_grad_set=set('X'))
def test_check_grad_ignore_y(self):
self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
class TestKronOp2(TestKronOp):
def setUp(self):
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# check no_grad_set is None
NOT_CHECK_OP_LIST = ['deformable_conv', 'row_conv']
NOT_CHECK_OP_LIST = ['deformable_conv', 'row_conv', 'kron']
# TODO(Shixiaowei02): Check if the items do not need fix.
# no_grad_set has value in NEED_TO_FIX_OP_LIST
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册