提交 b7ee1e7d 编写于 作者: D dongzhihong

"backward check todo"

上级 789d6ed9
@@ -42,18 +42,18 @@ template <typename Place, typename T>
 class RowwiseAddGradKernel : public OpKernel {
  public:
   void Compute(const ExecutionContext& context) const override {
-    auto XGrad = context.Output<Tensor>(0);
-    auto bGrad = context.Output<Tensor>(1);
+    auto* XGrad = context.Output<Tensor>(0);
+    auto* bGrad = context.Output<Tensor>(1);
     XGrad->mutable_data<T>(context.GetPlace());
     bGrad->mutable_data<T>(context.GetPlace());
     // I, O, OG => [X, b], [Out], [OutGrad]
     auto OutGrad = EigenMatrix<T>::From(*context.Input<Tensor>(3));
-    EigenMatrix<T>::From(*XGrad).device(*(context.GetEigenDevice<Place>())) =
+    EigenMatrix<T>::From(*XGrad).device(context.GetEigenDevice<Place>()) =
         OutGrad;
     // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
-    EigenVector<T>::Flatten(*bGrad).device(*(context.GetEigenDevice<Place>())) =
+    EigenVector<T>::Flatten(*bGrad).device(context.GetEigenDevice<Place>()) =
         OutGrad.cumsum(1);  // colwise add
   }
 };
...
@@ -15,5 +15,7 @@ class TestRowwiseAddOp(unittest.TestCase):
         self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
 
+
+#TODO(dzh): rowwise_grad check
 if __name__ == '__main__':
     unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册