diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst
index 2ae68d01d34a8cc786ab04d8325b8378c3b037ea..550b0e5b82609750ccd318eee889313cb2d7925a 100644
--- a/doc/api/v2/fluid/layers.rst
+++ b/doc/api/v2/fluid/layers.rst
@@ -514,3 +514,8 @@ l2_normalize
 ------------
 .. autofunction:: paddle.v2.fluid.layers.l2_normalize
     :noindex:
+
+sequence_reshape
+----------------
+.. autofunction:: paddle.v2.fluid.layers.sequence_reshape
+    :noindex:
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 84c010df7c396fc21904ae3c980f5fad70b2ceac..831b1e2a1e10777d9e89364adcd4b1f367e86080 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -485,9 +485,15 @@ void OperatorWithKernel::Run(const Scope& scope,
   // }

   auto expected_kernel_key = this->GetExpectedKernelType(ctx);
-  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
+
+  auto kernel_iter = kernels.find(expected_kernel_key);
+  if (kernel_iter == kernels.end()) {
+    PADDLE_THROW("op %s does not have kernel for %s", type_,
+                 KernelTypeToString(expected_kernel_key));
+  }
+
+  // do data transform
   Scope& new_scope = scope.NewScope();

   for (auto& var_name_item : this->Inputs()) {
@@ -520,8 +526,6 @@ void OperatorWithKernel::Run(const Scope& scope,
     }
   }

-  auto kernel_iter = kernels.find(expected_kernel_key);
-
   auto* new_dev_ctx = pool.Get(expected_kernel_key.place_);
   kernel_iter->second->Compute(
       ExecutionContext(*this, new_scope, *new_dev_ctx));
diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc
index 8d1479bdd6311709baaf2a6c673db3d0de4610f8..e007d71fbc9df27765f00161461c257fb18c2c2f 100644
--- a/paddle/operators/recv_op.cc
+++ b/paddle/operators/recv_op.cc
@@ -34,9 +34,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-constexpr int kCondStart = 0;
-constexpr int kCondRunning = 1;
-constexpr int kCondDone = 2;
+constexpr char kOptimizeBlock[] = "OptimizeBlock";

 void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
   service->RunSyncUpdate();
@@ -99,10 +97,8 @@ class RecvOp : public framework::OperatorBase {
     auto fan_in = Attr<int>("Fanin");
     size_t param_count = param_list.size();

-    std::string program_str = Attr<std::string>("OptimizeProgram");
-    framework::proto::ProgramDesc program_desc;
-    program_desc.ParseFromString(program_str);
-    framework::ProgramDesc program(program_desc);
+    auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
+    auto *program = block->Program();
     framework::Executor executor(dev_place);

     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
@@ -142,8 +138,9 @@ class RecvOp : public framework::OperatorBase {
       if (exit_flag) {
         break;
       }
+
       try {
-        executor.Run(program, &recv_scope, 0, /*global_block*/
+        executor.Run(*program, &recv_scope, block->ID(), /*global_block*/
                      false /*create_local_scope*/, false /*create_vars*/);
       } catch (std::exception &e) {
         LOG(ERROR) << "run sub program error " << e.what();
@@ -175,8 +172,8 @@ This operator will recv tensor from send_op
                          "IP address to listen on.")
         .SetDefault("127.0.0.1:6164")
         .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
-    AddAttr<std::string>("OptimizeProgram", "type string",
-                         "Serialized ProgramDesc string for recv to run.");
+    AddAttr<framework::BlockDesc *>(
+        kOptimizeBlock, "The block to be executed on the server side.");
     AddAttr<std::vector<std::string>>(
         "ParamList", "type list of string",
         "grad->param name mapping to find which param to optimize.")
diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc
index f634ebe9a2a4648bd08f00af635ef22e8d86a8de..c0e614743a894dece2cdc395d0b28df7e86e921d 100644
--- a/paddle/operators/scale_op.cc
+++ b/paddle/operators/scale_op.cc
@@ -48,7 +48,7 @@ Scale operator
 $$Out = scale*X$$
 )DOC");
     AddAttr<AttrType>("scale",
-                      "(float, default 0)"
+                      "(float, default 1.0)"
                       "The scaling factor of the scale operator.")
         .SetDefault(1.0);
   }
diff --git a/paddle/operators/send_recv_op_test.cc b/paddle/operators/send_recv_op_test.cc
index ea091694798475dfd9631910a750405be950c20c..045a0f5434f339bab345d14881ed05450ce6588d 100644
--- a/paddle/operators/send_recv_op_test.cc
+++ b/paddle/operators/send_recv_op_test.cc
@@ -130,10 +130,7 @@ void StartServerNet(bool is_sparse) {
   attrs.insert({"endpoint", std::string("127.0.0.1:6174")});
   attrs.insert({"ParamList", std::vector<std::string>({"Out"})});
   attrs.insert({"GradList", std::vector<std::string>({"x1"})});
-  std::string program_proto;
-  PADDLE_ENFORCE(program.Proto()->SerializeToString(&program_proto));
-
-  attrs.insert({"OptimizeProgram", program_proto});
+  attrs.insert({"OptimizeBlock", block});
   recv_op = f::OpRegistry::CreateOp("recv", {{"RX", {"x1"}}}, {}, attrs);
   recv_op->Run(scope, place);
 }
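With this change the server-side program is attached to the `recv` op as a block attribute rather than as a serialized `ProgramDesc` string. A minimal Python sketch of the new wiring, illustrative only (`optimize_sub_program` and the gradient variable are placeholders; the real call sites are the transpilers patched later in this diff):

```python
import paddle.v2.fluid as fluid

# Assumed to be populated with optimization ops elsewhere.
optimize_sub_program = fluid.Program()

program = fluid.default_main_program()
grad_vars = [program.global_block().create_var(
    name="x1", shape=[10], dtype="float32")]  # hypothetical gradient var

program.global_block().append_op(
    type="recv",
    inputs={"RX": grad_vars},  # gradients to receive
    outputs={},
    attrs={
        # Before this patch: "OptimizeProgram": optimize_sub_program.desc,
        # a serialized ProgramDesc. Now the block is passed directly.
        "OptimizeBlock": optimize_sub_program.global_block(),
        "endpoint": "127.0.0.1:6174",
        "ParamList": ["Out"],
        "GradList": ["x1"],
    })
```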
diff --git a/paddle/operators/sequence_reshape_op.cc b/paddle/operators/sequence_reshape_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..884c49276cead5b329217fceb174d40b2d632025
--- /dev/null
+++ b/paddle/operators/sequence_reshape_op.cc
@@ -0,0 +1,130 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sequence_reshape_op.h"
+#include "paddle/framework/ddim.h"
+
+namespace paddle {
+namespace operators {
+
+class SequenceReshapeOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceReshapeOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SequenceReshapeOp should not be null.");
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_numel = product(x_dims);
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
+    int new_dim = ctx->Attrs().Get<int>("new_dim");
+    ctx->SetOutputDim("Out",
+                      {x_numel / new_dim, static_cast<int64_t>(new_dim)});
+  }
+};
+
+class SequenceReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SequenceReshapeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X",
+             "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with "
+             "shape [N, M].");
+    AddOutput("Out",
+              "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor with "
+              "shape [T, new_dim] where T is calculated based on X.lod, M and "
+              "new_dim.");
+    AddAttr<int>("new_dim", "Sequence dimension of the output LoDTensor.");
+    AddComment(R"DOC(
+Sequence Reshape Operator.
+
+This operator will rearrange the input sequences. The new dimension is set by
+the attribute `new_dim`; the length of each sequence may become longer or
+shorter, depending on the original length, the original dimension, and the
+new dimension. The following example will help to illustrate the function of
+this operator:
+
+x is a LoDTensor:
+    x.lod  = [[0, 2, 6]]
+    x.data = [[1, 2], [3, 4],
+              [5, 6], [7, 8], [9, 10], [11, 12]]
+    x.dims = [6, 2]
+
+set new_dim = 4
+
+then out is a LoDTensor:
+    out.lod  = [[0, 1, 3]]
+    out.data = [[1, 2, 3, 4],
+                [5, 6, 7, 8], [9, 10, 11, 12]]
+    out.dims = [3, 4]
+
+Currently, only a 1-level LoDTensor is supported. Please make sure that
+(original length * original dimension) is divisible by new_dim with no
+remainder for each sequence.
+
+)DOC");
+  }
+};
+
+class SequenceReshapeGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(
+        ctx->HasInput(framework::GradVarName("Out")),
+        "Input(Out@GRAD) of SequenceReshapeGradOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceReshapeGradOp should not be null.");
+
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ framework::GradVarName("X"));
+  }
+};
+
+class SequenceReshapeGradOpMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op_desc_ptr = new framework::OpDesc();
+    op_desc_ptr->SetType("sequence_reshape_grad");
+    op_desc_ptr->SetInput("X", Input("X"));
+    op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op_desc_ptr->SetAttrMap(Attrs());
+    return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(sequence_reshape, ops::SequenceReshapeOp,
+                  ops::SequenceReshapeOpMaker,
+                  ops::SequenceReshapeGradOpMaker);
+REGISTER_OPERATOR(sequence_reshape_grad, ops::SequenceReshapeGradOp);
+REGISTER_OP_CPU_KERNEL(
+    sequence_reshape,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceReshapeKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    sequence_reshape_grad,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CPUDeviceContext,
+                                   int64_t>);
diff --git a/paddle/operators/sequence_reshape_op.cu b/paddle/operators/sequence_reshape_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d9c2f7e9a4149371867cf2a8b81d58566999bfba
--- /dev/null
+++ b/paddle/operators/sequence_reshape_op.cu
@@ -0,0 +1,30 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sequence_reshape_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    sequence_reshape,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    sequence_reshape_grad,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext,
+                                   float>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext,
+                                   double>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SequenceReshapeGradKernel<paddle::platform::CUDADeviceContext,
+                                   int64_t>);
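To make the LoD arithmetic in the kernel header below concrete, here is a small numpy sketch, not part of the patch, that mirrors what `SequenceReshapeKernel` computes for the example in the DOC string above:

```python
import numpy as np


def sequence_reshape_ref(x, lod, new_dim):
    """Reference LoD recomputation mirroring the C++ kernel (a sketch)."""
    in_width = x.shape[1]
    out_lod = [0]
    for i in range(len(lod) - 1):
        seq_len = lod[i + 1] - lod[i]
        # Each sequence keeps all its elements; only its length changes.
        assert (seq_len * in_width) % new_dim == 0, \
            "sequence %d is not divisible by new_dim" % (i + 1)
        out_lod.append(out_lod[-1] + seq_len * in_width // new_dim)
    out = x.reshape(out_lod[-1], new_dim)
    return out, [out_lod]


x = np.arange(1, 13).reshape(6, 2)          # x.lod = [[0, 2, 6]]
out, out_lod = sequence_reshape_ref(x, [0, 2, 6], new_dim=4)
print(out_lod)  # [[0, 1, 3]]
print(out)      # [[1 2 3 4], [5 6 7 8], [9 10 11 12]]
```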
diff --git a/paddle/operators/sequence_reshape_op.h b/paddle/operators/sequence_reshape_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..dd9b611250bf61f82ee23a31717cb4363f0c388e
--- /dev/null
+++ b/paddle/operators/sequence_reshape_op.h
@@ -0,0 +1,86 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+using LoDTensor = framework::LoDTensor;
+
+template <typename DeviceContext, typename T>
+class SequenceReshapeKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* in = context.Input<LoDTensor>("X");
+    auto* out = context.Output<LoDTensor>("Out");
+    int out_width = context.Attr<int>("new_dim");
+
+    auto in_dims = in->dims();
+    int64_t in_width = in_dims[1];
+    auto& in_lod = in->lod();
+
+    PADDLE_ENFORCE_EQ(in_lod.size(), 1UL,
+                      "Only support one level sequence now.");
+    PADDLE_ENFORCE_EQ(
+        in_dims[0], in_lod[0].back(),
+        "Inconsistent size between X.shape[0] and X.lod()[0].back().");
+
+    auto in_lod_l0 = in_lod[0];
+    int seq_num = in_lod_l0.size() - 1;
+
+    if (in_width == out_width) {
+      out->set_lod(in->lod());
+    } else {
+      auto& out_lod = *out->mutable_lod();
+      out_lod.resize(1);
+      out_lod[0].resize(seq_num + 1);
+      out_lod[0][0] = 0;
+      for (int i = 0; i < seq_num; ++i) {
+        size_t seq_len = in_lod_l0[i + 1] - in_lod_l0[i];
+        size_t offset = (seq_len * in_width) / out_width;
+        PADDLE_ENFORCE_EQ(offset * out_width, seq_len * in_width,
+                          "Please make sure (sequence_length * dimension) can "
+                          "be divided by new_dim with no remainder for each "
+                          "sequence. The %dth sequence is invalid.",
+                          i + 1);
+        out_lod[0][i + 1] = out_lod[0][i] + offset;
+      }
+    }
+
+    framework::Copy(*in, context.GetPlace(), out);
+    out->Resize({static_cast<int64_t>(out->lod()[0].back()), out_width});
+  }
+};
+
+template <typename DeviceContext, typename T>
+class SequenceReshapeGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x_tensor_ptr = context.Input<LoDTensor>("X");
+    auto* outg_tensor_ptr =
+        context.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* xg_tensor_ptr =
+        context.Output<LoDTensor>(framework::GradVarName("X"));
+
+    xg_tensor_ptr->mutable_data<T>(context.GetPlace());
+    framework::Copy(*outg_tensor_ptr, context.GetPlace(), xg_tensor_ptr);
+    xg_tensor_ptr->Resize(x_tensor_ptr->dims());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py
index bd957f88de5d51a2fa3e482284e2d8080f1be76b..02a0e4cd2639e857bce07afa9858531e8d177ad0 100644
--- a/python/paddle/v2/fluid/distribute_transpiler.py
+++ b/python/paddle/v2/fluid/distribute_transpiler.py
@@ -452,7 +452,7 @@ class DistributeTranspiler:
             },  # grads to recv
             outputs={},
             attrs={
-                "OptimizeProgram": optimize_sub_program.desc,
+                "OptimizeBlock": optimize_sub_program.global_block(),
                 "endpoint": endpoint,
                 "ParamList": [
                     p.name
diff --git a/python/paddle/v2/fluid/distribute_transpiler_simple.py b/python/paddle/v2/fluid/distribute_transpiler_simple.py
index bd88f02bde0c6a58138e20db2b07cbd06cd40ba3..56ffb56b1247646903485e5859b60f63df9b97a2 100644
--- a/python/paddle/v2/fluid/distribute_transpiler_simple.py
+++ b/python/paddle/v2/fluid/distribute_transpiler_simple.py
@@ -243,7 +243,7 @@ class SimpleDistributeTranspiler:
                 self.param_grad_map[endpoint]["grads"]},  # grads to recv
             outputs={},
             attrs={
-                "OptimizeProgram": optimize_sub_program.desc,
+                "OptimizeBlock": optimize_sub_program.global_block(),
                 "endpoint": endpoint,
                 "ParamList":
                 [p.name for p in self.param_grad_map[endpoint]["params"]],
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index 2314ae04fdfcea809e5be5184d2f90a3d9e90f34..5d05046bbac825e09aeb3dab18b9812edbabd732 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -28,7 +28,8 @@ __all__ = [
     'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
     'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min',
     'sequence_first_step', 'sequence_last_step', 'dropout', 'split',
-    'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', 'matmul', 'warpctc'
+    'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', 'matmul', 'warpctc',
+    'sequence_reshape'
 ]


@@ -213,33 +214,33 @@ def dynamic_lstm(input,
    (https://arxiv.org/pdf/1402.1128.pdf), the formula is as follows:

     .. math::
-
-        i_t & = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i)

-        f_t & = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f)
+        i_t & = \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + W_{ic}c_{t-1} + b_i)

-        \\tilde{c_t} & = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
+        f_t & = \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + W_{fc}c_{t-1} + b_f)

-        o_t & = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o)
+        \\tilde{c_t} & = act_g(W_{cx}x_t + W_{ch}h_{t-1} + b_c)

-        c_t & = f_t \odot c_{t-1} + i_t \odot \\tilde{c_t}
+        o_t & = \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + W_{oc}c_t + b_o)
+
+        c_t & = f_t \odot c_{t-1} + i_t \odot \\tilde{c_t}

         h_t & = o_t \odot act_h(c_t)
-    where the :math:`W` terms denote weight matrices (e.g. :math:`W_{xi}` is
+    where the :math:`W` terms denote weight matrices (e.g. :math:`W_{xi}` is
     the matrix of weights from the input gate to the input), :math:`W_{ic}, \
-    W_{fc}, W_{oc}` are diagonal weight matrices for peephole connections. In
-    our implementation, we use vectors to reprenset these diagonal weight
-    matrices. The :math:`b` terms denote bias vectors (:math:`b_i` is the input
-    gate bias vector), :math:`\sigma` is the non-line activations, such as
-    logistic sigmoid function, and :math:`i, f, o` and :math:`c` are the input
-    gate, forget gate, output gate, and cell activation vectors, respectively,
+    W_{fc}, W_{oc}` are diagonal weight matrices for peephole connections. In
+    our implementation, we use vectors to represent these diagonal weight
+    matrices. The :math:`b` terms denote bias vectors (:math:`b_i` is the input
+    gate bias vector), :math:`\sigma` is the non-linear activation, such as the
+    logistic sigmoid function, and :math:`i, f, o` and :math:`c` are the input
+    gate, forget gate, output gate, and cell activation vectors, respectively,
     all of which have the same size as
     the cell output activation vector :math:`h`.

-    The :math:`\odot` is the element-wise product of the vectors. :math:`act_g`
-    and :math:`act_h` are the cell input and cell output activation functions
-    and `tanh` is usually used for them. :math:`\\tilde{c_t}` is also called
-    candidate hidden state, which is computed based on the current input and
+    The :math:`\odot` is the element-wise product of the vectors. :math:`act_g`
+    and :math:`act_h` are the cell input and cell output activation functions
+    and `tanh` is usually used for them. :math:`\\tilde{c_t}` is also called
+    candidate hidden state, which is computed based on the current input and
     the previous hidden state.

     Set `use_peepholes` to `False` to disable peephole connection. The formula
@@ -251,38 +252,38 @@ def dynamic_lstm(input,
     Users can choose to use fully-connect layer before LSTM layer.

     Args:
-        input(Variable): The input of dynamic_lstm layer, which supports
-                         variable-time length input sequence. The underlying
-                         tensor in this Variable is a matrix with shape
-                         (T X 4D), where T is the total time steps in this
+        input(Variable): The input of dynamic_lstm layer, which supports
+                         variable-time length input sequence. The underlying
+                         tensor in this Variable is a matrix with shape
+                         (T X 4D), where T is the total time steps in this
                          mini-batch, D is the hidden size.
         size(int): 4 * hidden size.
-        param_attr(ParamAttr): The parameter attribute for the learnable
-                               hidden-hidden weights.
+        param_attr(ParamAttr): The parameter attribute for the learnable
+                               hidden-hidden weights.

-                               - The shape is (D x 4D), where D is the hidden
-                                 size.
+                               - The shape is (D x 4D), where D is the hidden
+                                 size.
                                - Weights = {:math:`W_{ch}, W_{ih}, \
                                  W_{fh}, W_{oh}`}
         bias_attr(ParamAttr): The bias attribute for the learnable bias
-                              weights, which contains two parts, input-hidden
-                              bias weights and peephole connections weights if
-                              setting `use_peepholes` to `True`.
+                              weights, which contains two parts, input-hidden
+                              bias weights and peephole connections weights if
+                              setting `use_peepholes` to `True`.

-                              1. `use_peepholes = False`
-                                - The shape is (1 x 4D).
+                              1. `use_peepholes = False`
+                                - The shape is (1 x 4D).
                                 - Biases = {:math:`b_c, b_i, b_f, b_o`}.
-                              2. `use_peepholes = True`
-                                - The shape is (1 x 7D).
+                              2. `use_peepholes = True`
+                                - The shape is (1 x 7D).
                                 - Biases = { :math:`b_c, b_i, b_f, b_o, W_{ic}, \
                                   W_{fc}, W_{oc}`}.
-        use_peepholes(bool): Whether to enable diagonal/peephole connections,
+        use_peepholes(bool): Whether to enable diagonal/peephole connections,
                              default `True`.
         is_reverse(bool): Whether to compute reversed LSTM, default `False`.
-        gate_activation(str): The activation for input gate, forget gate and
-                              output gate. Choices = ["sigmoid", "tanh", "relu",
+        gate_activation(str): The activation for input gate, forget gate and
+                              output gate. Choices = ["sigmoid", "tanh", "relu",
                               "identity"], default "sigmoid".
-        cell_activation(str): The activation for cell output. Choices = ["sigmoid",
+        cell_activation(str): The activation for cell output. Choices = ["sigmoid",
                               "tanh", "relu", "identity"], default "tanh".
         candidate_activation(str): The activation for candidate hidden state.
                                    Choices = ["sigmoid", "tanh", "relu", "identity"],
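The hunk above only re-flows the docstring. As a sanity check on the equations, here is a small numpy sketch, illustrative only and not part of the patch, of one peephole-LSTM step with `tanh` for both `act_g` and `act_h`:

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x_t, h_prev, c_prev, W, b):
    # One peephole-LSTM step following the docstring equations; W holds the
    # weight matrices (peephole entries W['ic'], W['fc'], W['oc'] are
    # vectors, i.e. diagonal matrices), b holds the bias vectors.
    i_t = sigmoid(np.dot(W['ix'], x_t) + np.dot(W['ih'], h_prev) +
                  W['ic'] * c_prev + b['i'])
    f_t = sigmoid(np.dot(W['fx'], x_t) + np.dot(W['fh'], h_prev) +
                  W['fc'] * c_prev + b['f'])
    c_tilde = np.tanh(np.dot(W['cx'], x_t) + np.dot(W['ch'], h_prev) + b['c'])
    c_t = f_t * c_prev + i_t * c_tilde
    # Note the output gate peeks at the *current* cell state c_t.
    o_t = sigmoid(np.dot(W['ox'], x_t) + np.dot(W['oh'], h_prev) +
                  W['oc'] * c_t + b['o'])
    h_t = o_t * np.tanh(c_t)
    return h_t, c_t
```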
@@ -2027,3 +2028,57 @@ def warpctc(input, label, blank=0, norm_by_times=False, **kwargs):
         attrs={'blank': blank,
                'norm_by_times': norm_by_times})
     return loss_out
+
+
+def sequence_reshape(input, new_dim):
+    """
+    **Sequence Reshape Layer**
+
+    This layer will rearrange the input sequences. The new dimension is set
+    by the user. The length of each sequence is computed from the original
+    length, the original dimension and the new dimension. The following
+    example will help to illustrate the function of this layer:
+
+    .. code-block:: text
+
+        x is a LoDTensor:
+            x.lod  = [[0, 2, 6]]
+            x.data = [[1, 2], [3, 4],
+                      [5, 6], [7, 8], [9, 10], [11, 12]]
+            x.dims = [6, 2]
+
+        set new_dim = 4
+
+        then out is a LoDTensor:
+            out.lod  = [[0, 1, 3]]
+            out.data = [[1, 2, 3, 4],
+                        [5, 6, 7, 8], [9, 10, 11, 12]]
+            out.dims = [3, 4]
+
+    Currently, only a 1-level LoDTensor is supported. Please make sure that
+    (original length * original dimension) is divisible by the new dimension
+    with no remainder for each sequence.
+
+    Args:
+        input (Variable): (LoDTensor, default: LoDTensor<float>), a 2-D
+            LoDTensor with shape [N, M] where M is the dimension.
+        new_dim (int): New dimension which the input LoDTensor is reshaped to.
+
+    Returns:
+        Variable: Reshaped LoDTensor according to new dimension.
+
+    Examples:
+        .. code-block:: python
+
+            x = fluid.layers.data(name='x', shape=[5, 20],
+                                  dtype='float32', lod_level=1)
+            x_reshaped = fluid.layers.sequence_reshape(input=x, new_dim=10)
+    """
+    helper = LayerHelper('sequence_reshape', **locals())
+    out = helper.create_tmp_variable(helper.input_dtype())
+    helper.append_op(
+        type='sequence_reshape',
+        inputs={'X': [input]},
+        outputs={'Out': [out]},
+        attrs={'new_dim': new_dim})
+    return out
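Since the op rejects any sequence whose (length * dimension) is not a multiple of `new_dim`, callers may want to validate the LoD up front. A tiny illustrative helper, hypothetical and not part of the patch:

```python
def check_reshapable(lod_level_0, old_dim, new_dim):
    """Return True iff every sequence can be reshaped to new_dim columns."""
    for i in range(len(lod_level_0) - 1):
        seq_len = lod_level_0[i + 1] - lod_level_0[i]
        if (seq_len * old_dim) % new_dim != 0:
            return False
    return True


assert check_reshapable([0, 2, 6], old_dim=2, new_dim=4)      # 4 and 8 elements, both divisible
assert not check_reshapable([0, 3, 6], old_dim=2, new_dim=4)  # 6 elements, not divisible by 4
```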
diff --git a/python/paddle/v2/fluid/tests/book_distribute/notest_dist_image_classification.py b/python/paddle/v2/fluid/tests/book_distribute/notest_dist_image_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..218dea31e10757d901c5524567f13501b64dbea5
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/book_distribute/notest_dist_image_classification.py
@@ -0,0 +1,172 @@
+#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import print_function
+
+import sys
+
+import paddle.v2 as paddle
+import paddle.v2.fluid as fluid
+import os
+
+TRAINERS = 5
+BATCH_SIZE = 128
+PASS_NUM = 100
+
+
+def resnet_cifar10(input, depth=32):
+    def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'):
+        tmp = fluid.layers.conv2d(
+            input=input,
+            filter_size=filter_size,
+            num_filters=ch_out,
+            stride=stride,
+            padding=padding,
+            act=None,
+            bias_attr=False)
+        return fluid.layers.batch_norm(input=tmp, act=act)
+
+    def shortcut(input, ch_in, ch_out, stride):
+        if ch_in != ch_out:
+            return conv_bn_layer(input, ch_out, 1, stride, 0, None)
+        else:
+            return input
+
+    def basicblock(input, ch_in, ch_out, stride):
+        tmp = conv_bn_layer(input, ch_out, 3, stride, 1)
+        tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None)
+        short = shortcut(input, ch_in, ch_out, stride)
+        return fluid.layers.elementwise_add(x=tmp, y=short, act='relu')
+
+    def layer_warp(block_func, input, ch_in, ch_out, count, stride):
+        tmp = block_func(input, ch_in, ch_out, stride)
+        for i in range(1, count):
+            tmp = block_func(tmp, ch_out, ch_out, 1)
+        return tmp
+
+    assert (depth - 2) % 6 == 0
+    n = (depth - 2) / 6
+    conv1 = conv_bn_layer(
+        input=input, ch_out=16, filter_size=3, stride=1, padding=1)
+    res1 = layer_warp(basicblock, conv1, 16, 16, n, 1)
+    res2 = layer_warp(basicblock, res1, 16, 32, n, 2)
+    res3 = layer_warp(basicblock, res2, 32, 64, n, 2)
+    pool = fluid.layers.pool2d(
+        input=res3, pool_size=8, pool_type='avg', pool_stride=1)
+    return pool
+
+
+def vgg16_bn_drop(input):
+    def conv_block(input, num_filter, groups, dropouts):
+        return fluid.nets.img_conv_group(
+            input=input,
+            pool_size=2,
+            pool_stride=2,
+            conv_num_filter=[num_filter] * groups,
+            conv_filter_size=3,
+            conv_act='relu',
+            conv_with_batchnorm=True,
+            conv_batchnorm_drop_rate=dropouts,
+            pool_type='max')
+
+    conv1 = conv_block(input, 64, 2, [0.3, 0])
+    conv2 = conv_block(conv1, 128, 2, [0.4, 0])
+    conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0])
+    conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0])
+    conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0])
+
+    drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5)
+    fc1 = fluid.layers.fc(input=drop, size=512, act=None)
+    bn = fluid.layers.batch_norm(input=fc1, act='relu')
+    drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5)
+    fc2 = fluid.layers.fc(input=drop2, size=512, act=None)
+    return fc2
+
+
+classdim = 10
+data_shape = [3, 32, 32]
+
+images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
+label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+net_type = "vgg"
+if len(sys.argv) >= 2:
+    net_type = sys.argv[1]
+
+if net_type == "vgg":
+    print("train vgg net")
+    net = vgg16_bn_drop(images)
+elif net_type == "resnet":
+    print("train resnet")
+    net = resnet_cifar10(images, 32)
+else:
+    raise ValueError("%s network is not supported" % net_type)
+
+predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
+cost = fluid.layers.cross_entropy(input=predict, label=label)
+avg_cost = fluid.layers.mean(x=cost)
+
+optimizer = fluid.optimizer.Adam(learning_rate=0.001)
+optimize_ops, params_grads = optimizer.minimize(avg_cost)
+
+accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
+
+train_reader = paddle.batch(
+    paddle.reader.shuffle(
+        paddle.dataset.cifar.train10(), buf_size=128 * 10),
+    batch_size=BATCH_SIZE)
+
+place = fluid.CPUPlace()
+exe = fluid.Executor(place)
+
+t = fluid.DistributeTranspiler()
+# all parameter server endpoints list for splitting parameters
+pserver_endpoints = os.getenv("PSERVERS")
+# server endpoint for current node
+current_endpoint = os.getenv("SERVER_ENDPOINT")
+# run as trainer or parameter server
+training_role = os.getenv("TRAINING_ROLE",
+                          "TRAINER")  # get the training role: trainer/pserver
+t.transpile(
+    optimize_ops, params_grads, pservers=pserver_endpoints, trainers=TRAINERS)
+
+if training_role == "PSERVER":
+    if not current_endpoint:
+        print("need env SERVER_ENDPOINT")
+        exit(1)
+    print("start pserver at:", current_endpoint)
+    pserver_prog = t.get_pserver_program(current_endpoint)
+    pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
+    exe.run(pserver_startup)
+    exe.run(pserver_prog)
+    print("pserver run end")
+elif training_role == "TRAINER":
+    print("start trainer")
+    trainer_prog = t.get_trainer_program()
+    feeder = fluid.DataFeeder(place=place, feed_list=[images, label])
+    exe.run(fluid.default_startup_program())
+    for pass_id in range(PASS_NUM):
+        accuracy.reset(exe)
+        for data in train_reader():
+            loss, acc = exe.run(trainer_prog,
+                                feed=feeder.feed(data),
+                                fetch_list=[avg_cost] + accuracy.metrics)
+            pass_acc = accuracy.eval(exe)
+            print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" +
+                  str(pass_acc))
+        # this model is slow, so if we can train two mini batches,
+        # we think it works properly.
+    print("trainer run end")
+else:
+    print("environment var TRAINING_ROLE should be TRAINER or PSERVER")
+    exit(1)
diff --git a/python/paddle/v2/fluid/tests/test_layers.py b/python/paddle/v2/fluid/tests/test_layers.py
index a4e155b534a41e385167e6a6f01e32cfedf580e2..b366e5ba3681cee4552090f08877c46b6b7baaa3 100644
--- a/python/paddle/v2/fluid/tests/test_layers.py
+++ b/python/paddle/v2/fluid/tests/test_layers.py
@@ -216,6 +216,14 @@ class TestBook(unittest.TestCase):
             self.assertIsNotNone(x)
         print(str(program))

+    def test_sequence_reshape(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(name='x', shape=[8], dtype='float32', lod_level=1)
+            out = layers.sequence_reshape(input=x, new_dim=16)
+        self.assertIsNotNone(out)
+        print(str(program))
+

 if __name__ == '__main__':
     unittest.main()
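For reference, the backward pass exercised by `check_grad` in the new test below is just the upstream gradient copied back into the input's shape, matching `SequenceReshapeGradKernel` above. A numpy sketch, not part of the patch:

```python
import numpy as np


def sequence_reshape_grad_ref(out_grad, x_shape):
    # The grad kernel is an identity copy of Out@GRAD reshaped to X's dims.
    return out_grad.reshape(x_shape)


out_grad = np.ones((3, 4), dtype='float32')   # gradient w.r.t. Out
x_grad = sequence_reshape_grad_ref(out_grad, (6, 2))
assert x_grad.shape == (6, 2)
```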
diff --git a/python/paddle/v2/fluid/tests/test_sequence_reshape.py b/python/paddle/v2/fluid/tests/test_sequence_reshape.py
new file mode 100644
index 0000000000000000000000000000000000000000..857b15237ae872278c1c4f3d1bfe13d1b69fb6b1
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_sequence_reshape.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import unittest
+import numpy as np
+import math
+from op_test import OpTest
+
+
+class TestSequenceReshape(OpTest):
+    def setUp(self):
+        self.op_type = 'sequence_reshape'
+        dimension = 12
+        x_lod = [[0, 4, 5, 8, 11]]
+        x = np.random.uniform(0.1, 1, [11, 24]).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'new_dim': dimension}
+
+        out, out_lod = self.compute_output(x, x_lod, dimension)
+
+        self.outputs = {'Out': (out, out_lod)}
+
+    def compute_output(self, x, x_lod, dimension):
+        x_width = x.shape[1]
+        out_lod = [[0]]
+        for i in xrange(len(x_lod[0]) - 1):
+            seq_len = x_lod[0][i + 1] - x_lod[0][i]
+            offset = (seq_len * x_width) / dimension
+            assert int(offset) * dimension == seq_len * x_width
+            out_lod[0].append(out_lod[0][-1] + int(offset))
+        out = np.zeros(shape=(out_lod[0][-1], dimension)).astype('float32')
+        out.ravel()[:] = x.ravel()[:]
+        return out, out_lod
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
+class TestSequenceReshape_reduce(TestSequenceReshape):
+    def setUp(self):
+        self.op_type = 'sequence_reshape'
+        dimension = 24
+        x_lod = [[0, 4, 6, 8, 12]]
+        x = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'new_dim': dimension}
+
+        out, out_lod = self.compute_output(x, x_lod, dimension)
+
+        self.outputs = {'Out': (out, out_lod)}
+
+
+class TestSequenceReshape_same(TestSequenceReshape):
+    def setUp(self):
+        self.op_type = 'sequence_reshape'
+        dimension = 12
+        x_lod = [[0, 4, 6, 8, 12]]
+        x = np.random.uniform(0.1, 1, [12, 12]).astype('float32')
+
+        self.inputs = {'X': (x, x_lod)}
+        self.attrs = {'new_dim': dimension}
+
+        out, out_lod = self.compute_output(x, x_lod, dimension)
+
+        self.outputs = {'Out': (out, out_lod)}
+
+
+if __name__ == '__main__':
+    unittest.main()