提交 b7ee1e7d 编写于 作者: D dongzhihong

"backward check todo"

上级 789d6ed9
......@@ -42,18 +42,18 @@ template <typename Place, typename T>
class RowwiseAddGradKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
auto XGrad = context.Output<Tensor>(0);
auto bGrad = context.Output<Tensor>(1);
auto* XGrad = context.Output<Tensor>(0);
auto* bGrad = context.Output<Tensor>(1);
XGrad->mutable_data<T>(context.GetPlace());
bGrad->mutable_data<T>(context.GetPlace());
// I, O, OG => [X, b], [Out], [OutGrad]
auto OutGrad = EigenMatrix<T>::From(*context.Input<Tensor>(3));
EigenMatrix<T>::From(*XGrad).device(*(context.GetEigenDevice<Place>())) =
EigenMatrix<T>::From(*XGrad).device(context.GetEigenDevice<Place>()) =
OutGrad;
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
EigenVector<T>::Flatten(*bGrad).device(*(context.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(*bGrad).device(context.GetEigenDevice<Place>()) =
OutGrad.cumsum(1); // colwise add
}
};
......
......@@ -15,5 +15,7 @@ class TestRowwiseAddOp(unittest.TestCase):
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
#TODO(dzh): rowwise_grad check
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册