未验证 提交 78cd3dd5 编写于 作者: F Feiyu Chan 提交者: GitHub

fix kron_op: when only one input needs gradient, test=develop (#24269)

fix kron_op: when only one input needs gradient
上级 5dc069d0
...@@ -99,17 +99,15 @@ class KronGradOp : public framework::OperatorWithKernel { ...@@ -99,17 +99,15 @@ class KronGradOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_grad"); OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "kron_grad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "kron_grad"); framework::GradVarName("Out"), "kron_grad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "kron_grad");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Y")), "Output",
framework::GradVarName("Y"), "kron_grad");
auto x_grad_name = framework::GradVarName("X"); auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y"); auto y_grad_name = framework::GradVarName("Y");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X")); if (ctx->HasOutput(x_grad_name)) {
ctx->ShareLoD("X", /*->*/ x_grad_name); ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y")); }
ctx->ShareLoD("Y", /*->*/ y_grad_name); if (ctx->HasOutput(y_grad_name)) {
ctx->SetOutputDim(y_grad_name, ctx->GetInputDim("Y"));
}
} }
protected: protected:
......
...@@ -147,11 +147,14 @@ struct KronGradElemFunctor { ...@@ -147,11 +147,14 @@ struct KronGradElemFunctor {
index_b += stride_b_[i] * pos_bi; index_b += stride_b_[i] * pos_bi;
} }
size_t index_out_a = index_a * numel_b_ + index_b; if (dout_a_) {
size_t index_out_b = index_b * numel_a_ + index_a; size_t index_out_a = index_a * numel_b_ + index_b;
dout_a_[index_out_a] = dout_[idx] * B_[index_b];
dout_a_[index_out_a] = dout_[idx] * B_[index_b]; }
dout_b_[index_out_b] = dout_[idx] * A_[index_a]; if (dout_b_) {
size_t index_out_b = index_b * numel_a_ + index_a;
dout_b_[index_out_b] = dout_[idx] * A_[index_a];
}
} }
private: private:
...@@ -222,35 +225,50 @@ struct KronGradOpFunctor { ...@@ -222,35 +225,50 @@ struct KronGradOpFunctor {
// dout_x: dout * kron(ones(X), Y) re-aranged in shape (numel_x, numel_y) // dout_x: dout * kron(ones(X), Y) re-aranged in shape (numel_x, numel_y)
// dout_y: dout * kron(X, ones(Y)) re-aranged in shaoe (numel_y, numel_x) // dout_y: dout * kron(X, ones(Y)) re-aranged in shaoe (numel_y, numel_x)
framework::Tensor dout_x; framework::Tensor dout_x;
dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace()); T* p_dout_x = nullptr;
if (dx) {
dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
p_dout_x = dout_x.data<T>();
}
framework::Tensor dout_y; framework::Tensor dout_y;
dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace()); T* p_dout_y = nullptr;
if (dy) {
dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
p_dout_y = dout_y.data<T>();
}
platform::ForRange<DeviceContext> for_range(dev_ctx, numel); platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
KronGradElemFunctor<T> func(dout.data<T>(), x.data<T>(), y.data<T>(), KronGradElemFunctor<T> func(dout.data<T>(), x.data<T>(), y.data<T>(),
dout_x.data<T>(), dout_y.data<T>(), p_dout_x, p_dout_y, p_stride_dout, p_stride_x,
p_stride_dout, p_stride_x, p_stride_y, p_stride_y, p_shape_y, numel_x, numel_y, ndims);
p_shape_y, numel_x, numel_y, ndims);
for_range(func); for_range(func);
// reduce_sum along aixs 1 // reduce_sum along aixs 1
#if __NVCC__ #if __NVCC__
auto stream = dev_ctx.stream(); // it is a cuda device_context auto stream = dev_ctx.stream(); // it is a cuda device_context
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>( if (dx) {
dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(), TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
stream); dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>( stream);
dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(), }
stream); if (dy) {
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
stream);
}
#else #else
auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
auto* place = dev_ctx.eigen_device(); auto* place = dev_ctx.eigen_device();
Eigen::array<int, 1> reduce_dim = {1}; Eigen::array<int, 1> reduce_dim = {1};
eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim); if (dx) {
eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim); auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
}
if (dy) {
auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
}
#endif #endif
} }
}; };
...@@ -307,17 +325,33 @@ class KronGradKernel : public framework::OpKernel<T> { ...@@ -307,17 +325,33 @@ class KronGradKernel : public framework::OpKernel<T> {
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
dx->mutable_data<T>(ctx.GetPlace()); if (dx) {
dy->mutable_data<T>(ctx.GetPlace()); dx->mutable_data<T>(ctx.GetPlace());
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
}
int ndims = dout->dims().size(); int ndims = dout->dims().size();
framework::Tensor xx = UnsqueezeTo(*x, ndims); framework::Tensor xx = UnsqueezeTo(*x, ndims);
framework::Tensor dxx = UnsqueezeTo(*dx, ndims);
framework::Tensor yy = UnsqueezeTo(*y, ndims); framework::Tensor yy = UnsqueezeTo(*y, ndims);
framework::Tensor dyy = UnsqueezeTo(*dy, ndims);
framework::Tensor* pdxx = nullptr;
framework::Tensor* pdyy = nullptr;
framework::Tensor dxx;
framework::Tensor dyy;
if (dx) {
dxx = UnsqueezeTo(*dx, ndims);
pdxx = &dxx;
}
if (dy) {
dyy = UnsqueezeTo(*dy, ndims);
pdyy = &dyy;
}
KronGradOpFunctor<DeviceContext, T> func; KronGradOpFunctor<DeviceContext, T> func;
func(dev_ctx, *dout, xx, yy, &dxx, &dyy); func(dev_ctx, *dout, xx, yy, pdxx, pdyy);
} }
}; };
......
...@@ -42,6 +42,12 @@ class TestKronOp(OpTest): ...@@ -42,6 +42,12 @@ class TestKronOp(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X', 'Y'], 'Out') self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ignore_x(self):
self.check_grad(['Y'], 'Out', no_grad_set=set('X'))
def test_check_grad_ignore_y(self):
self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
class TestKronOp2(TestKronOp): class TestKronOp2(TestKronOp):
def setUp(self): def setUp(self):
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# check no_grad_set is None # check no_grad_set is None
NOT_CHECK_OP_LIST = ['deformable_conv', 'row_conv'] NOT_CHECK_OP_LIST = ['deformable_conv', 'row_conv', 'kron']
# TODO(Shixiaowei02): Check if the items do not need fix. # TODO(Shixiaowei02): Check if the items do not need fix.
# no_grad_set has value in NEED_TO_FIX_OP_LIST # no_grad_set has value in NEED_TO_FIX_OP_LIST
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册