diff --git a/paddle/fluid/operators/lookup_table_op.cc b/paddle/fluid/operators/lookup_table_op.cc index 3e8f3ec5c5cd683343bcbdfc2388bd37c25e00f9..d77b095c5d783a2a9fab87eb8b458117a6a3d225 100644 --- a/paddle/fluid/operators/lookup_table_op.cc +++ b/paddle/fluid/operators/lookup_table_op.cc @@ -32,11 +32,16 @@ class LookupTableOp : public framework::OperatorWithKernel { auto table_dims = ctx->GetInputDim("W"); auto ids_dims = ctx->GetInputDim("Ids"); + int ids_rank = ids_dims.size(); - PADDLE_ENFORCE_EQ(ids_dims.size(), 2); - PADDLE_ENFORCE_EQ(ids_dims[1], 1); + PADDLE_ENFORCE_EQ(table_dims.size(), 2); + PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1, + "The last dimension of the 'Ids' tensor must be 1."); - ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); + auto output_dims = + framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1)); + output_dims.push_back(table_dims[1]); + ctx->SetOutputDim("Out", framework::make_ddim(output_dims)); if (ctx->GetOutputsVarType("Out")[0] == framework::proto::VarType::LOD_TENSOR) { @@ -61,8 +66,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Ids", "An input with type int32 or int64 " "contains the ids to be looked up in W. " - "Ids must be a column vector with rank = 2. " - "The 2nd dimension size must be 1."); + "The last dimension size must be 1."); AddOutput("Out", "The lookup results, which have the same type as W."); AddAttr("is_sparse", "(boolean, default false) " diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu index 27483372b93a850d313445386c7973838c4a0710..74823dab09cac358f647c074ac2f2ee2fed17e55 100644 --- a/paddle/fluid/operators/lookup_table_op.cu +++ b/paddle/fluid/operators/lookup_table_op.cu @@ -118,28 +118,31 @@ class LookupTableGradCUDAKernel : public framework::OpKernel { auto *d_table = context.Output(framework::GradVarName("W")); auto *ids_data = ids->data(); - auto ids_dim = ids->dims(); + int64_t ids_num = ids->numel(); auto stream = dev_ctx.stream(); // copy GPU memory to CPU pinned memory framework::Vector new_rows; - new_rows.resize(ids_dim[0]); + new_rows.resize(ids_num); auto gpu_place = boost::get(context.GetPlace()); // TODO(yuyang18): Strange code here. memory::Copy(platform::CPUPlace(), new_rows.CUDAMutableData(context.GetPlace()), gpu_place, - ids_data, ids_dim[0] * sizeof(int64_t), stream); + ids_data, ids_num * sizeof(int64_t), stream); d_table->set_rows(new_rows); auto *d_table_value = d_table->mutable_value(); - d_table_value->Resize({ids_dim[0], table->dims()[1]}); + d_table_value->Resize({ids_num, table->dims()[1]}); d_table_value->mutable_data(context.GetPlace()); auto *d_table_data = d_table_value->data(); auto *d_output_data = d_output->data(); - PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims()); + auto d_output_dims = d_output->dims(); + PADDLE_ENFORCE_EQ( + d_table_value->dims(), + framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1)); memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data, d_output->numel() * sizeof(T), stream); diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index c9f074ca0e8dafb374dc9368165df5af5053a6b8..f5c10ced8305b64c6386c5051804f8c9a8f71802 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -109,17 +109,17 @@ class LookupTableGradKernel : public framework::OpKernel { auto *d_table = context.Output(framework::GradVarName("W")); auto *ids_data = ids->data(); - auto ids_dim = ids->dims(); + int64_t ids_num = ids->numel(); framework::Vector new_rows; - new_rows.reserve(ids_dim[0]); - for (int64_t i = 0; i < ids_dim[0]; i++) { + new_rows.reserve(ids_num); + for (int64_t i = 0; i < ids_num; i++) { new_rows.push_back(ids_data[i]); } d_table->set_rows(new_rows); auto *d_table_value = d_table->mutable_value(); - d_table_value->Resize({ids_dim[0], table_dim[1]}); + d_table_value->Resize({ids_num, table_dim[1]}); d_table_value->mutable_data(context.GetPlace()); d_table->set_height(table_dim[0]); @@ -127,7 +127,10 @@ class LookupTableGradKernel : public framework::OpKernel { auto *d_output_data = d_output->data(); auto *d_table_data = d_table_value->data(); - PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims()); + auto d_output_dims = d_output->dims(); + PADDLE_ENFORCE_EQ( + d_table_value->dims(), + framework::flatten_to_2d(d_output_dims, d_output_dims.size() - 1)); memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); } else { auto *ids = context.Input("Ids"); @@ -135,10 +138,9 @@ class LookupTableGradKernel : public framework::OpKernel { auto *d_table = context.Output(framework::GradVarName("W")); auto *ids_data = ids->data(); - auto ids_dim = ids->dims(); int N = table_dim[0]; - int D = d_output->dims()[1]; + int D = table_dim[1]; auto *d_output_data = d_output->data(); auto *d_table_data = d_table->mutable_data(context.GetPlace()); diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index 5596fa0648ccc151bc0d11de9c556599428a8d71..2bdb23e999621b10799b5163f326bc4b66a437e6 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -30,8 +30,16 @@ class SoftmaxCUDNNKernel : public framework::OpKernel { // allocate memory on device. Out->mutable_data(context.GetPlace()); + auto dims = X->dims(); + auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); + framework::LoDTensor flattened_x; + framework::LoDTensor flattened_out; + flattened_x.ShareDataWith(*X).Resize(flattened_dims); + flattened_out.ShareDataWith(*Out).Resize(flattened_dims); + math::SoftmaxCUDNNFunctor()( - context.template device_context(), X, Out); + context.template device_context(), + &flattened_x, &flattened_out); } }; @@ -46,9 +54,18 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel { // allocate memory on device. dX->mutable_data(context.GetPlace()); + auto dims = Out->dims(); + auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); + framework::LoDTensor flattened_out; + framework::LoDTensor flattened_d_out; + framework::LoDTensor flattened_d_x; + flattened_out.ShareDataWith(*Out).Resize(flattened_dims); + flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims); + flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims); + math::SoftmaxGradCUDNNFunctor()( - context.template device_context(), Out, - dOut, dX); + context.template device_context(), + &flattened_out, &flattened_d_out, &flattened_d_x); } }; diff --git a/paddle/fluid/operators/softmax_mkldnn_op.cc b/paddle/fluid/operators/softmax_mkldnn_op.cc index 6668e6b9e917eea7ba4a80ac78917b73eb827208..01819f53e3ab0973f6140c5a81f18f954b6a0376 100644 --- a/paddle/fluid/operators/softmax_mkldnn_op.cc +++ b/paddle/fluid/operators/softmax_mkldnn_op.cc @@ -26,9 +26,9 @@ using paddle::platform::MKLDNNMemDesc; using mkldnn::memory; // Note: paddle has also "memory" namespace using mkldnn::primitive; -using mkldnn::softmax_forward; -using mkldnn::softmax_backward; using mkldnn::prop_kind; +using mkldnn::softmax_backward; +using mkldnn::softmax_forward; using mkldnn::stream; using platform::to_void_cast; @@ -113,17 +113,27 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel { auto mkldnn_engine = dev_ctx.GetEngine(); const Tensor* input = ctx.Input("X"); Tensor* output = ctx.Output("Out"); - PADDLE_ENFORCE(input->dims().size() == 2UL, - "The input of softmax op must be a 2D matrix."); - const T* input_data = input->data(); - // allocate memory for output - T* output_data = output->mutable_data(ctx.GetPlace()); - std::vector src_tz = paddle::framework::vectorize2int(input->dims()); - std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); - // MKL-DNN does support softmax over selected axis. Having 2D Tensor, - // we will make normalization after final eg. axis: 1 - PADDLE_ENFORCE(((src_tz[0] == dst_tz[0]) && (src_tz[1] == dst_tz[1])), - "Softmax input and output dimensions should match"); + PADDLE_ENFORCE_EQ( + input->dims(), output->dims(), + "The shape of softmax's input and output must be identical."); + + // make sure 'output' holds memory, which will be shared by + // 'flattened_output' later. + output->mutable_data(ctx.GetPlace()); + + // flatten input and output to 2-D matrixs + auto dims = input->dims(); // input and output share the same shape + auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); + framework::Tensor flattened_input; + framework::Tensor flattened_output; + flattened_input.ShareDataWith(*input).Resize(flattened_dims); + flattened_output.ShareDataWith(*output).Resize(flattened_dims); + + const T* input_data = flattened_input.data(); + T* output_data = flattened_output.mutable_data(ctx.GetPlace()); + + std::vector src_tz = paddle::framework::vectorize2int(flattened_dims); + std::vector dst_tz = src_tz; // Same memory descriptor to be used for input and output memory::dims softmax_tz = {src_tz[0], src_tz[1]}; // Generate keys for storing/retriving primitives for this operator @@ -174,23 +184,34 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel { auto& dev_ctx = ctx.template device_context(); auto mkldnn_engine = dev_ctx.GetEngine(); const Tensor* output = ctx.Input("Out"); - const T* dst_data = output->data(); - auto* dout = ctx.template Input(framework::GradVarName("Out")); - const auto* diff_dst_ptr = dout->template data(); - auto* dx = ctx.template Output(framework::GradVarName("X")); - T* diff_src_ptr = dx->template mutable_data(ctx.GetPlace()); - std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); + PADDLE_ENFORCE_EQ( + dout->dims(), dx->dims(), + "The shape of softmax_grad's input and output must be identical."); + + // make sure 'dx' holds memory, which will be shared by 'flattened_dx' + // later. + dx->template mutable_data(ctx.GetPlace()); + + auto dims = dout->dims(); // input and output share the same shape + auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); + framework::Tensor flattened_output; + framework::Tensor flattened_dout; + framework::Tensor flattened_dx; + flattened_output.ShareDataWith(*output).Resize(flattened_dims); + flattened_dout.ShareDataWith(*dout).Resize(flattened_dims); + flattened_dx.ShareDataWith(*dx).Resize(flattened_dims); + + const T* dst_data = flattened_output.data(); + const T* diff_dst_ptr = flattened_dout.template data(); + T* diff_src_ptr = flattened_dx.template mutable_data(ctx.GetPlace()); + + std::vector dst_tz = paddle::framework::vectorize2int(flattened_dims); std::vector src_tz(dst_tz); - PADDLE_ENFORCE(output->dims().size() == 2UL, - "The input of softmax op must be a 2D matrix."); - // MKL-DNN does support softmax over selected axis. Having 2D Tensor, - // we will make normalization after final eg. axis: 1 - PADDLE_ENFORCE(((src_tz[0] == dst_tz[0]) && (src_tz[1] == dst_tz[1])), - "Softmax input and output dimensions should match"); + // Same memory descriptor to be used for input and output memory::dims softmax_tz = {src_tz[0], src_tz[1]}; // Currently only supports NC data format diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index fefc7125b4de7274589670d29be4511469d5064a..bb081238820b9ee3ae095442d21cfce11f7b41e5 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -37,10 +37,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of SoftmaxOp should not be null."); - auto x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE(x_dims.size() == 2UL, - "The input of softmax op must be a matrix."); - ctx->SetOutputDim("Out", x_dims); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ "Out"); } @@ -81,8 +78,8 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "The input tensor of softmax. " - "2-D with shape [batch_size, input_feature_dimensions]."); + "The input tensor of softmax, " + "whose last dimension is the input_feature_dimensions."); AddOutput("Out", "The normalized values with the same shape as X.") .Reuse("X"); AddAttr( @@ -105,20 +102,23 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Softmax Operator. -The input of the softmax operator is a 2-D tensor with shape N x K (N is the -batch_size, K is the dimension of input feature). The output tensor has the -same shape as the input tensor. +The input of the softmax operator is a tensor of any rank. The output tensor +has the same shape as the input. -For each row of the input tensor, the softmax operator squashes the -K-dimensional vector of arbitrary real values to a K-dimensional vector of real -values in the range [0, 1] that add up to 1. +The input tensor will first be logically flattened to a 2-D matrix. The matrix's +second dimension(row length) is as same as the last dimension of the input +tensor, and the first dimension(column length) is the product of all other +dimensions of the input tensor. For each row of the matrix, the softmax operator +squashes the K-dimensional(K is the width of the matrix, which is also the size +of the input tensor's last dimension) vector of arbitrary real values to a +K-dimensional vector of real values in the range [0, 1] that add up to 1. It computes the exponential of the given dimension and the sum of exponential values of all the other dimensions in the K-dimensional vector input. Then the ratio of the exponential of the given dimension and the sum of exponential values of all the other dimensions is the output of the softmax operator. -For each row $i$ and each column $j$ in Input(X), we have: +For each row $i$ and each column $j$ in the matrix, we have: $$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$ )DOC"); diff --git a/paddle/fluid/operators/softmax_op.h b/paddle/fluid/operators/softmax_op.h index 600da45a0bbb69b76d59c981e195fc03a49b0504..1205bd0587f32caae04c27ecea581fc17988507f 100644 --- a/paddle/fluid/operators/softmax_op.h +++ b/paddle/fluid/operators/softmax_op.h @@ -31,8 +31,16 @@ class SoftmaxKernel : public framework::OpKernel { // allocate memory on device. Out->mutable_data(context.GetPlace()); + auto dims = X->dims(); + auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); + framework::LoDTensor flattened_x; + framework::LoDTensor flattened_out; + flattened_x.ShareDataWith(*X).Resize(flattened_dims); + flattened_out.ShareDataWith(*Out).Resize(flattened_dims); + math::SoftmaxFunctor()( - context.template device_context(), X, Out); + context.template device_context(), &flattened_x, + &flattened_out); } }; @@ -47,8 +55,18 @@ class SoftmaxGradKernel : public framework::OpKernel { // allocate memory on device. dX->mutable_data(context.GetPlace()); + auto dims = Out->dims(); + auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); + framework::LoDTensor flattened_out; + framework::LoDTensor flattened_d_out; + framework::LoDTensor flattened_d_x; + flattened_out.ShareDataWith(*Out).Resize(flattened_dims); + flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims); + flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims); + math::SoftmaxGradFunctor()( - context.template device_context(), Out, dOut, dX); + context.template device_context(), &flattened_out, + &flattened_d_out, &flattened_d_x); } }; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 058acd4a50ef54cea724a742d40eaca8f569a21c..12e7170fc3da83071f4a23b6c39463d8c2543391 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -1313,13 +1313,16 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True): def softmax(input, param_attr=None, bias_attr=None, use_cudnn=True, name=None): """ - The input of the softmax layer is a 2-D tensor with shape N x K (N is the - batch_size, K is the dimension of input feature). The output tensor has the - same shape as the input tensor. + The input of the softmax operator is a tensor of any rank. The output tensor + has the same shape as the input. - For each row of the input tensor, the softmax operator squashes the - K-dimensional vector of arbitrary real values to a K-dimensional vector of real - values in the range [0, 1] that add up to 1. + The input tensor will first be logically flattened to a 2-D matrix. The matrix's + second dimension(row length) is as same as the last dimension of the input + tensor, and the first dimension(column length) is the product of all other + dimensions of the input tensor. For each row of the matrix, the softmax operator + squashes the K-dimensional(K is the width of the matrix, which is also the size + of the input tensor's last dimension) vector of arbitrary real values to a + K-dimensional vector of real values in the range [0, 1] that add up to 1. It computes the exponential of the given dimension and the sum of exponential values of all the other dimensions in the K-dimensional vector input. @@ -1327,7 +1330,7 @@ def softmax(input, param_attr=None, bias_attr=None, use_cudnn=True, name=None): exponential values of all the other dimensions is the output of the softmax operator. - For each row :math:`i` and each column :math:`j` in Input(X), we have: + For each row :math:`i` and each column :math:`j` in the matrix, we have: .. math:: diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py index e16ab1d15f165bd0efa1b7d51add36c3020a1910..ad0d555198c36c12fd1cc39c41d39b24b40f64c3 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_op.py @@ -35,6 +35,22 @@ class TestLookupTableOp(OpTest): self.check_grad(['W'], 'Out', no_grad_set=set('Ids')) +class TestLookupTableOpWithTensorIds(OpTest): + def setUp(self): + self.op_type = "lookup_table" + table = np.random.random((17, 31)).astype("float32") + ids = np.random.randint( + low=0, high=17, size=(2, 4, 5, 1)).astype("int64") + self.inputs = {'W': table, 'Ids': ids} + self.outputs = {'Out': table[ids.flatten()].reshape((2, 4, 5, 31))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['W'], 'Out', no_grad_set=set('Ids')) + + class TestLookupTableOpWithPadding(TestLookupTableOp): def test_check_output(self): ids = np.squeeze(self.inputs['Ids']) @@ -44,21 +60,34 @@ class TestLookupTableOpWithPadding(TestLookupTableOp): self.check_output() def test_check_grad(self): - # Since paddings are not trainable and fixed in forward, the gradient of + # Since paddings are not trainable and fixed in forward, the gradient of # paddings makes no sense and we don't test the gradient here. pass -class TestLookupTableWIsSelectedRows(OpTest): - def check_with_place(self, place): - scope = core.Scope() +class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds): + def test_check_output(self): + ids = self.inputs['Ids'] + flatten_idx = ids.flatten() + padding_idx = np.random.choice(flatten_idx, 1)[0] + self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) + self.attrs = {'padding_idx': long(padding_idx)} + self.check_output() + + def test_check_grad(self): + # Since paddings are not trainable and fixed in forward, the gradient of + # paddings makes no sense and we don't test the gradient here. + pass - # create and initialize Id Variable + +class TestLookupTableWIsSelectedRows(OpTest): + def prepare_ids(self, scope, place): ids_tensor = scope.var('Ids').get_tensor() ids_array = np.array([[0], [4], [3], [5]]).astype("int64") ids_tensor.set(ids_array, place) + return ids_array - # create and initialize W Variable + def prepare_w(self, scope, place): rows = [0, 1, 2, 3, 4, 5, 6] row_numel = 12 @@ -71,8 +100,22 @@ class TestLookupTableWIsSelectedRows(OpTest): w_tensor = w_selected_rows.get_tensor() w_tensor.set(w_array, place) - # create Out Variable - out_tensor = scope.var('Out').get_tensor() + def create_out_tensor(self, scope, place): + return scope.var('Out').get_tensor() + + def check_result(self, ids_array, result_array): + # all(): return True if all elements of the iterable are true (or if the iterable is empty) + for idx, row in enumerate(ids_array): + assert (row[0] == result_array[idx]).all() + + def check_with_place(self, place): + scope = core.Scope() + + ids_array = self.prepare_ids(scope, place) + + self.prepare_w(scope, place) + + out_tensor = self.create_out_tensor(scope, place) # create and run lookup_table operator lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out') @@ -80,9 +123,8 @@ class TestLookupTableWIsSelectedRows(OpTest): # get result from Out result_array = np.array(out_tensor) - # all(): return True if all elements of the iterable are true (or if the iterable is empty) - for idx, row in enumerate(ids_array): - assert (row[0] == result_array[idx]).all() + + self.check_result(ids_array, result_array) def test_w_is_selected_rows(self): places = [core.CPUPlace()] @@ -91,5 +133,19 @@ class TestLookupTableWIsSelectedRows(OpTest): self.check_with_place(place) +class TestLookupTableWithTensorIdsWIsSelectedRows( + TestLookupTableWIsSelectedRows): + def prepare_ids(self, scope, place): + ids_tensor = scope.var('Ids').get_tensor() + ids_array = np.random.randint( + low=0, high=6, size=(2, 4, 3, 1)).astype("int64") + ids_tensor.set(ids_array, place) + return ids_array + + def check_result(self, ids_array, result_array): + for idx, row in np.ndenumerate(ids_array): + assert (row == result_array[idx]).all() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py index 0ab581cfb0ea0ff2205450b8e62edb8bf3c51707..70ad05597c4a160cf6a25aeb3c379320cef69c63 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py @@ -26,15 +26,22 @@ def stable_softmax(x): class TestSoftmaxOp(OpTest): + def get_x_shape(self): + return [10, 10] + def setUp(self): self.op_type = "softmax" self.use_cudnn = False self.use_mkldnn = False self.dtype = np.float32 self.init_kernel_type() + self.shape = self.get_x_shape() + + x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) + out = np.apply_along_axis(stable_softmax, 1, + x.reshape([-1, self.shape[-1]])) + out = out.reshape(self.shape) - x = np.random.uniform(0.1, 1, [10, 10]).astype(self.dtype) - out = np.apply_along_axis(stable_softmax, 1, x) self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} self.outputs = {'Out': out} self.attrs = { @@ -63,6 +70,11 @@ class TestSoftmaxOp(OpTest): self.check_grad(["X"], "Out", max_relative_error=0.01) +class TestSoftmaxOp2(TestSoftmaxOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestSoftmaxCUDNNOp(TestSoftmaxOp): @@ -70,6 +82,13 @@ class TestSoftmaxCUDNNOp(TestSoftmaxOp): self.use_cudnn = True +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp2(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestSoftmaxFP16Op(TestSoftmaxOp): @@ -83,6 +102,13 @@ class TestSoftmaxFP16Op(TestSoftmaxOp): self.check_output_with_place(place, atol=1e-3) +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxFP16Op2(TestSoftmaxFP16Op): + def get_x_shape(self): + return [2, 3, 4, 5] + + @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp): @@ -97,10 +123,22 @@ class TestSoftmaxFP16CUDNNOp(TestSoftmaxOp): self.check_output_with_place(place, atol=1e-3) +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + class TestSoftmaxMKLDNNOp(TestSoftmaxOp): def init_kernel_type(self): self.use_mkldnn = True +class TestSoftmaxMKLDNNOp2(TestSoftmaxMKLDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + if __name__ == "__main__": unittest.main()