diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 21166354937c378dc3f295f9011d034eb24cfc7c..87efb900cd59e6adeb051e0e458f2b86c1b510c9 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -101,8 +101,8 @@ set(DEPS_OPS op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor net_op) op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) -op_library(cross_entropy_op DEPS cross_entropy_function) -op_library(softmax_with_cross_entropy_op DEPS cross_entropy_function softmax_function) +op_library(cross_entropy_op DEPS cross_entropy) +op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 6bea9817f1b6c76e68e2a3023bb9eac591aa894f..b39d4f0ac27bf0a8378344f852a602c5ecf4cf6a 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,17 +1,15 @@ if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc - im2col.cu DEPS cblas device_context operator) + im2col.cu DEPS cblas device_context operator) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) - nv_library(softmax_function SRCS softmax.cc softmax.cu - DEPS operator) - nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu - DEPS operator) + nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) + nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) else() cc_library(math_function SRCS math_function.cc im2col.cc - DEPS cblas device_context operator) + DEPS cblas device_context operator) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) - cc_library(softmax_function SRCS softmax.cc DEPS operator) - cc_library(cross_entropy_function SRCS 
cross_entropy.cc DEPS operator) + cc_library(softmax SRCS softmax.cc DEPS operator) + cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) endif() cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/softmax.cc b/paddle/operators/math/softmax.cc index ac9f3c4bf61bf8e13faa17387f1112756db9a100..0ba8197ab8b64649c8adcf67771ba01eca7f1d10 100644 --- a/paddle/operators/math/softmax.cc +++ b/paddle/operators/math/softmax.cc @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ #include "paddle/operators/math/softmax.h" @@ -19,6 +19,7 @@ namespace operators { namespace math { template class SoftmaxFunctor; +template class SoftmaxGradFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/softmax.cu b/paddle/operators/math/softmax.cu index 4c3df0550e7ca6f4310db1d35cc34d5c73a2dd16..99f988d51e4b16c3f3bfd9c76b411bb53619603e 100644 --- a/paddle/operators/math/softmax.cu +++ b/paddle/operators/math/softmax.cu @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ #define EIGEN_USE_GPU @@ -21,6 +21,7 @@ namespace operators { namespace math { template class SoftmaxFunctor; +template class SoftmaxGradFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h index 225323f05ac9aacce55dfe4795315741ee2c8795..b7f627eee7f8fe68a83595a3390a55d438c97afb 100644 --- a/paddle/operators/math/softmax.h +++ b/paddle/operators/math/softmax.h @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ #pragma once #include "paddle/framework/eigen.h" @@ -68,6 +68,37 @@ class SoftmaxFunctor { .broadcast(one_by_class)); } }; + +template +class SoftmaxGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor* y, const framework::Tensor* y_grad, + framework::Tensor* x_grad) { + auto softmax = EigenMatrix::From(*y); + auto softmax_grad = EigenMatrix::From(*y_grad); + auto logits_grad = EigenMatrix::From(*x_grad); + + const int kBatchDim = 0; + const int kClassDim = 1; + + const int batch_size = softmax.dimension(kBatchDim); + const int num_classes = softmax.dimension(kClassDim); + + Eigen::DSizes along_class(kClassDim); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); + + auto dot = (softmax * softmax_grad) + .sum(along_class) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class); + logits_grad.device(*context.GetEigenDevice()) = + (softmax_grad - dot) * softmax; + } +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 9858c4d9c2195c7bd0e767aaa86a950e0a791443..3c8fe04d2edeccc0e0d55aa2a28d71085ccf5145 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -1,16 +1,16 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #include "paddle/operators/mul_op.h" @@ -35,12 +35,14 @@ class MulOp : public framework::OperatorWithKernel { int x_num_col_dims = ctx->Attrs().Get("x_num_col_dims"); int y_num_col_dims = ctx->Attrs().Get("y_num_col_dims"); - PADDLE_ENFORCE(x_dims.size() > x_num_col_dims, - "The rank of input tensor X should be larger than " - "`mul_op`'s `x_num_col_dims`."); - PADDLE_ENFORCE(y_dims.size() > y_num_col_dims, - "The rank of input tensor Y should be larger than " - "`mul_op`'s `y_num_col_dims`."); + PADDLE_ENFORCE_GT( + x_dims.size(), x_num_col_dims, + "The input tensor X's rank of MulOp should be larger than " + "x_num_col_dims."); + PADDLE_ENFORCE_GT( + y_dims.size(), y_num_col_dims, + "The input tensor Y's rank of MulOp should be larger than " + "y_num_col_dims."); auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims); auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims); diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc index 17685ea654715f6996e17f6228f266c3aa1ee424..bc4af2f70427e684dfb531b8c61d68f28ae20794 100644 --- a/paddle/operators/sequence_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -24,9 +24,9 @@ class SequencePoolOp : public 
framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContextBase* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of SequenceAvgPoolOp should not be null."); + "Input(X) of SequencePoolOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of SequenceAvgPoolOp should not be null."); + "Output(Out) of SequencePoolOp should not be null."); ctx->SetOutputDim("Out", ctx->GetInputDim("X")); } }; diff --git a/paddle/operators/sequence_softmax_op.cc b/paddle/operators/sequence_softmax_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..621779ab6133f56a43fb2d20c814ebed8762ea7d --- /dev/null +++ b/paddle/operators/sequence_softmax_op.cc @@ -0,0 +1,103 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/sequence_softmax_op.h" + +namespace paddle { +namespace operators { + +class SequenceSoftmaxOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceSoftmaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SequenceSoftmaxOp should not be null."); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SequenceSoftmaxOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension " + "of length 1."); + AddOutput("Out", + "(LoDTensor) 1-D or 2-D output LoDTensor with the 2-nd dimension " + "of length 1."); + AddComment(R"DOC( +SequenceSoftmaxOp computes softmax activation among all time-steps for each +sequence. The dimension of each time-step should be 1. Thus, the shape of +input Tensor can be either [N, 1] or [N], where N is the sum of all sequences' +lengths. + +Equation: + for i-th sequence in a mini-batch: + Out(X[lod[i]:lod[i+1], :]) = + exp(X[lod[i]:lod[i+1], :]) / sum(exp(X[lod[i]:lod[i+1], :])) + +For example, for a mini-batch of 3 sequences with variable-length, +each containing 2, 3, 2 time-steps, the lod of which is [0, 2, 5, 7], +then softmax will be computed among X[0:2, :], X[2:5, :], X[5:7, :] +and N turns out to be 7. 
+)DOC"); + } +}; + +class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Out"), + "Input(Out) of SequenceSoftmaxGradOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) of SequenceSoftmaxGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SequenceSoftmaxGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) of SequenceSoftmaxGradOp should not be null."); + + PADDLE_ENFORCE_EQ( + ctx->GetInputDim("Out"), + ctx->GetInputDim(framework::GradVarName("Out")), + "Input(Out) and Input(Out@GRAD) of SequenceSoftmaxGradOp should be of " + "the same shape."); + + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(sequence_softmax, ops::SequenceSoftmaxOp, + ops::SequenceSoftmaxOpMaker, sequence_softmax_grad, + ops::SequenceSoftmaxGradOp); +REGISTER_OP_CPU_KERNEL( + sequence_softmax, + ops::SequenceSoftmaxKernel); +REGISTER_OP_CPU_KERNEL( + sequence_softmax_grad, + ops::SequenceSoftmaxGradKernel); diff --git a/paddle/operators/sequence_softmax_op.cu b/paddle/operators/sequence_softmax_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..f2a1e3d5e31ef21b95a51b287bdd1d4aa9221e89 --- /dev/null +++ b/paddle/operators/sequence_softmax_op.cu @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/operators/sequence_softmax_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + sequence_softmax, + ops::SequenceSoftmaxKernel) +REGISTER_OP_GPU_KERNEL( + sequence_softmax_grad, + ops::SequenceSoftmaxGradKernel); diff --git a/paddle/operators/sequence_softmax_op.h b/paddle/operators/sequence_softmax_op.h new file mode 100644 index 0000000000000000000000000000000000000000..96d87c404d217280d74bd088e7a23f539ef6e7ce --- /dev/null +++ b/paddle/operators/sequence_softmax_op.h @@ -0,0 +1,94 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/softmax.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class SequenceSoftmaxKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + auto lod = x->lod(); + auto dims = x->dims(); + + const size_t level = lod.size() - 1; + PADDLE_ENFORCE_EQ(dims[0], static_cast(lod[level].back()), + "The first dimension of Input(X) should be equal to the " + "sum of all sequences' lengths."); + PADDLE_ENFORCE_EQ(dims[0], x->numel(), + "The width of each timestep in Input(X) of " + "SequenceSoftmaxOp should be 1."); + + out->mutable_data(ctx.GetPlace()); + for (int i = 0; i < static_cast(lod[level].size()) - 1; ++i) { + int start_pos = static_cast(lod[level][i]); + int end_pos = static_cast(lod[level][i + 1]); + Tensor x_i = x->Slice(start_pos, end_pos); + Tensor out_i = out->Slice(start_pos, end_pos); + + // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos) + framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos}); + x_i.Resize(dims_i); + out_i.Resize(dims_i); + math::SoftmaxFunctor()(ctx.device_context(), &x_i, &out_i); + } + } +}; + +template +class SequenceSoftmaxGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Input("Out"); + auto* out_grad = ctx.Input(framework::GradVarName("Out")); + auto* x = ctx.Input("X"); + auto* x_grad = ctx.Output(framework::GradVarName("X")); + + auto lod = x->lod(); + const size_t level = lod.size() - 1; + + x_grad->mutable_data(ctx.GetPlace()); + for (int i = 0; i < static_cast(lod[level].size()) - 1; ++i) { + int start_pos = static_cast(lod[level][i]); + int end_pos = 
static_cast(lod[level][i + 1]); + + Tensor out_i = out->Slice(start_pos, end_pos); + Tensor out_grad_i = out_grad->Slice(start_pos, end_pos); + Tensor x_grad_i = x_grad->Slice(start_pos, end_pos); + + // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos) + framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos}); + out_i.Resize(dims_i); + out_grad_i.Resize(dims_i); + x_grad_i.Resize(dims_i); + math::SoftmaxGradFunctor()(ctx.device_context(), &out_i, + &out_grad_i, &x_grad_i); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 8fdda8b1dfc5dd40315682388dabe0bf2f2be555..2c08853f4f615bfe95f51aa20776ddddcdaa8f61 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -29,8 +29,8 @@ template class SoftmaxKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto X = context.Input("X"); - auto Y = context.Output("Y"); + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); // allocate memory on device. 
Y->mutable_data(context.GetPlace()); @@ -43,29 +43,14 @@ template class SoftmaxGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto Y = context.Input("Y"); - auto dY = context.Input(framework::GradVarName("Y")); - auto dX = context.Output(framework::GradVarName("X")); - dX->mutable_data(context.GetPlace()); - - const int batch_size = Y->dims()[0]; - const int class_num = Y->dims()[1]; - - Eigen::DSizes along_class(1); - Eigen::DSizes batch_by_one(batch_size, 1); - Eigen::DSizes one_by_class(1, class_num); + auto* Y = context.Input("Y"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); - auto Y_eigen = EigenMatrix::From(*Y); - auto dY_eigen = EigenMatrix::From(*dY); - auto dX_eigen = EigenMatrix::From(*dX); - auto place = context.GetEigenDevice(); + // allocate memory on device. + dX->mutable_data(context.GetPlace()); - auto dot = (Y_eigen * dY_eigen) - .sum(along_class) - .eval() - .reshape(batch_by_one) - .broadcast(one_by_class); - dX_eigen.device(place) = (dY_eigen - dot) * Y_eigen; + math::SoftmaxGradFunctor()(context.device_context(), Y, dY, dX); } }; diff --git a/python/paddle/v2/framework/tests/test_sequence_softmax_op.py b/python/paddle/v2/framework/tests/test_sequence_softmax_op.py new file mode 100644 index 0000000000000000000000000000000000000000..b54a56aa6d3f76baa4d1fc6ba8f963332deba002 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_sequence_softmax_op.py @@ -0,0 +1,38 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def stable_softmax(x): + """Compute the softmax of vector x in a numerically stable way.""" + shiftx = x - np.max(x).clip(-64.) 
+ exps = np.exp(shiftx) + return exps / np.sum(exps) + + +class TestSequenceSoftmaxOp(OpTest): + def setUp(self): + self.op_type = "sequence_softmax" + x = np.random.uniform(0.1, 1, (11, 1)).astype("float32") + lod = [[0, 4, 5, 8, 11]] + + out = np.zeros((11, 1)).astype("float32") + for i in range(4): + sub_x = x[lod[0][i]:lod[0][i + 1], :] + sub_x = sub_x.reshape(1, lod[0][i + 1] - lod[0][i]) + sub_out = stable_softmax(sub_x) + out[lod[0][i]:lod[0][i + 1], :] = sub_out.reshape( + lod[0][i + 1] - lod[0][i], 1) + + self.inputs = {"X": (x, lod)} + self.outputs = {"Out": out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out", max_relative_error=0.01) + + +if __name__ == "__main__": + unittest.main()