diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc index 1012640d5e2052e4f347ad458cea9072a004f334..c9744db3d0654ef63357963d9a9a3cb946f56e2d 100644 --- a/paddle/fluid/operators/parallel_do_op.cc +++ b/paddle/fluid/operators/parallel_do_op.cc @@ -295,7 +295,7 @@ class ParallelDoGradOp : public framework::OperatorBase { auto sum_op = framework::OpRegistry::CreateOp( "sum", {{"X", {s, tmp_name}}}, {{"Out", {s}}}, - framework::AttributeMap{}); + framework::AttributeMap{{"use_mkldnn", {false}}}); VLOG(10) << sum_op->DebugStringEx(sub_scopes[0]); sum_op->Run(*sub_scopes[0], places[0]); WaitOnPlace(places[0]); diff --git a/paddle/fluid/operators/recurrent_op.cc b/paddle/fluid/operators/recurrent_op.cc index 9c1cee7022a9b9a98f026f7602f0f7badc44a49b..162bfcbb0844d29385d0f8ad5d25a3f8de6bd41b 100644 --- a/paddle/fluid/operators/recurrent_op.cc +++ b/paddle/fluid/operators/recurrent_op.cc @@ -429,7 +429,8 @@ class RecurrentGradOp : public RecurrentBase { auto sum_op = framework::OpRegistry::CreateOp( "sum", {{"X", {pg_names[param_id], new_inside_name}}}, - {{"Out", {pg_names[param_id]}}}, framework::AttributeMap{}); + {{"Out", {pg_names[param_id]}}}, + framework::AttributeMap{{"use_mkldnn", {false}}}); sum_op->Run(cur_scope, place); cur_scope.Rename(new_inside_name, inside_grad_name); diff --git a/paddle/fluid/operators/sum_mkldnn_op.cc b/paddle/fluid/operators/sum_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f78d977760f18c9eb1270e515e68acb208a7c9a4 --- /dev/null +++ b/paddle/fluid/operators/sum_mkldnn_op.cc @@ -0,0 +1,240 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/*Licensed under the Apache License, Version 2.0(the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "mkldnn.hpp" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/operators/sum_op.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; +using paddle::platform::MKLDNNDeviceContext; +using paddle::platform::CPUDeviceContext; +using framework::DataLayout; +using mkldnn::memory; +using mkldnn::primitive; +using mkldnn::stream; +using mkldnn::sum; +using mkldnn::reorder; +using platform::to_void_cast; + +template +class SumMKLDNNOpKernel : public paddle::framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + auto& dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + auto in_vars = ctx.MultiInputVar("X"); + + const int N = in_vars.size(); + auto out_var = ctx.OutputVar("Out"); + bool in_place = out_var == in_vars[0]; + + if (out_var->IsType()) { + LoDTensor* output = ctx.Output("Out"); + T* output_data = output->mutable_data(ctx.GetPlace()); + + std::vector dst_tz = framework::vectorize2int(output->dims()); + auto src_tz = dst_tz; + memory::format output_format{memory::format::format_undef}; + std::vector scales; + std::vector srcs_mpd; + std::vector srcs_mem; + + PADDLE_ENFORCE(in_vars[0]->IsType(), + "Input[0] must be LoDTensors"); + auto& input0 = in_vars[0]->Get(); + PADDLE_ENFORCE(input0.layout() == DataLayout::kMKLDNN && + input0.format() != memory::format::format_undef, + "Wrong layout/format for inputs[0]"); + + memory::format input_format = input0.format(); + + if (src_tz.size() == 1 && (input_format == memory::format::nchw || + input_format == memory::format::nhwc)) { + input_format = memory::format::x; + } + if (src_tz.size() == 2 && (input_format == memory::format::nchw || + input_format == memory::format::nhwc)) { + input_format = memory::format::nc; + } + + for (int i = in_place ? 1 : 0; i < N; i++) { + PADDLE_ENFORCE(in_vars[i]->IsType(), + "all inputs must be all LoDTensors"); + auto& input = in_vars[i]->Get(); + PADDLE_ENFORCE(input.layout() == DataLayout::kMKLDNN && + input.format() != memory::format::format_undef, + "Wrong layout/format for inputs"); + + if (input.numel() == 0) { + continue; + } + + const T* input_data = input.data(); + + auto src_md = + memory::desc(src_tz, memory::data_type::f32, input_format); + auto src_mpd = memory::primitive_desc(src_md, mkldnn_engine); + auto src_mem = memory(src_mpd, to_void_cast(input_data)); + srcs_mpd.push_back(src_mpd); + srcs_mem.push_back(src_mem); + scales.push_back(1.0); + } + + auto dst_md = + memory::desc(dst_tz, memory::data_type::f32, memory::format::any); + + auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd); + + std::shared_ptr dst_mem; + if (in_place) { + dst_mem.reset(new memory(sum_pd.dst_primitive_desc())); + } else { + dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data)); + } + std::vector inputs; + for (size_t i = 0; i < srcs_mem.size(); ++i) { + inputs.push_back(srcs_mem[i]); + } + + auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem); + output_format = (memory::format)platform::GetMKLDNNFormat(sum_pd); + + primitive reorder_prim; + std::shared_ptr target_mem; + if (in_place) { + output_format = input_format; + target_mem.reset(new memory( + {{{src_tz}, memory::data_type::f32, output_format}, mkldnn_engine}, + output_data)); + reorder_prim = reorder(*dst_mem, *target_mem); + } + + std::vector pipeline; + pipeline.push_back(sum_prim); + if (in_place) pipeline.push_back(reorder_prim); + stream(stream::kind::eager).submit(pipeline).wait(); + + output->set_layout(DataLayout::kMKLDNN); + output->set_format(output_format); + } else if (out_var->IsType()) { + // TODO(@mozga-intel) Add MKLDNN SelectedRows support + std::unique_ptr in0; + if (in_place) { + // If is in_place, we store the input[0] to in0 + auto& in_sel0 = in_vars[0]->Get(); + auto& rows = in_sel0.rows(); + in0.reset(new framework::SelectedRows(rows, in_sel0.height())); + in0->mutable_value()->ShareDataWith(in_sel0.value()); + } + + auto get_selected_row = [&](size_t i) -> const SelectedRows& { + if (i == 0 && in0) { + return *in0.get(); + } else { + return in_vars[i]->Get(); + } + }; + auto* out = ctx.Output("Out"); + out->mutable_rows()->clear(); + auto* out_value = out->mutable_value(); + + // Runtime InferShape + size_t first_dim = 0; + for (int i = 0; i < N; i++) { + auto& sel_row = get_selected_row(i); + first_dim += sel_row.rows().size(); + } + auto in_dim = + framework::vectorize(get_selected_row(N - 1).value().dims()); + in_dim[0] = static_cast(first_dim); + + out_value->Resize(framework::make_ddim(in_dim)); + + // if all the input sparse vars are empty, no need to + // merge these vars. + if (first_dim == 0UL) { + return; + } + out_value->mutable_data(ctx.GetPlace()); + math::SelectedRowsAddTo functor; + int64_t offset = 0; + for (int i = 0; i < N; i++) { + auto& sel_row = get_selected_row(i); + if (sel_row.rows().size() == 0) { + continue; + } + PADDLE_ENFORCE_EQ(out->height(), sel_row.height()); + functor(ctx.template device_context(), sel_row, + offset, out); + offset += sel_row.value().numel(); + } + } else if (out_var->IsType()) { + // TODO(@mozga-intel) Add MKLDNN LoDTensorArray support + auto& out_array = *out_var->GetMutable(); + for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) { + PADDLE_ENFORCE(in_vars[i]->IsType(), + "Only support all inputs are TensorArray"); + auto& in_array = in_vars[i]->Get(); + + for (size_t i = 0; i < in_array.size(); ++i) { + if (in_array[i].numel() != 0) { + if (i >= out_array.size()) { + out_array.resize(i + 1); + } + if (out_array[i].numel() == 0) { + framework::TensorCopy(in_array[i], in_array[i].place(), + ctx.device_context(), &out_array[i]); + out_array[i].set_lod(in_array[i].lod()); + } else { + PADDLE_ENFORCE(out_array[i].lod() == in_array[i].lod()); + auto in = EigenVector::Flatten(in_array[i]); + auto result = EigenVector::Flatten(out_array[i]); + result.device(*ctx.template device_context() + .eigen_device()) = result + in; + } + } + } + } + } else { + PADDLE_THROW("Unexpected branch, output variable type is %s", + out_var->Type().name()); + } + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_KERNEL(sum, MKLDNN, ::paddle::platform::CPUPlace, + paddle::operators::SumMKLDNNOpKernel); diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 863baba9ea7663d0b21875e0b423dc4a6ce2d59a..fe7c7039c7dec714e265ede1b7167fd800ddc2f7 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -18,6 +18,10 @@ limitations under the License. */ #include "paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/operators/detail/safe_ref.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + namespace paddle { namespace operators { using framework::Tensor; @@ -63,6 +67,18 @@ class SumOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto x_vars = ctx.MultiInputVar("X"); + + framework::LibraryType library{framework::LibraryType::kPlain}; + framework::DataLayout layout{framework::DataLayout::kAnyLayout}; + +#ifdef PADDLE_WITH_MKLDNN + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } +#endif + if (x_vars[0]->IsType()) { int dtype = -1; for (auto& x_var : x_vars) { @@ -80,26 +96,27 @@ class SumOp : public framework::OperatorWithKernel { "Sum operator should have at least one tensor"); return framework::OpKernelType( - static_cast(dtype), - ctx.device_context()); + static_cast(dtype), ctx.GetPlace(), + layout, library); } else if (x_vars[0]->IsType()) { for (auto& var : x_vars) { auto& value = var->Get().value(); if (value.IsInitialized()) { return framework::OpKernelType(framework::ToDataType(value.type()), - ctx.device_context()); + ctx.device_context(), layout, library); } } // if input sparse vars are not initialized, use an default kernel type. return framework::OpKernelType(framework::proto::VarType::FP32, - ctx.device_context()); + ctx.device_context(), layout, library); } else if (x_vars[0]->IsType()) { for (auto& x_var : x_vars) { auto& array = x_var->Get(); for (auto& each : array) { if (each.numel() != 0) { return framework::OpKernelType(framework::ToDataType(each.type()), - ctx.device_context()); + ctx.device_context(), layout, + library); } } } @@ -116,6 +133,9 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "(vector) The input tensors of sum operator.") .AsDuplicable(); AddOutput("Out", "(Tensor) The output tensor of sum operator.").Reuse("X"); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddComment(R"DOC( Sum operator. @@ -132,7 +152,6 @@ class SumOpVarTypeInference : public framework::VarTypeInference { framework::BlockDesc* block) const override { auto& inputs = op_desc.Input("X"); auto var_type = framework::proto::VarType::SELECTED_ROWS; - for (auto& name : op_desc.Input("X")) { VLOG(10) << name << " " << block->FindRecursiveOrCreateVar(name).GetType(); @@ -206,6 +225,7 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(sum, ops::SumOp, ops::SumOpMaker, ops::SumGradMaker, ops::SumOpVarTypeInference); + REGISTER_OP_CPU_KERNEL( sum, ops::SumKernel, ops::SumKernel, diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index 175c3ac5d79f24e47d21417df8e3eaeb4d5b2335..f440058e8db2024f5c8a0129db3af87a80d6e551 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -203,11 +203,11 @@ class WhileGradOp : public framework::OperatorBase { ->set_lod(inside_tensor.lod()); } } - auto new_inside_name = cur_scope.Rename(inside_grad_name); auto sum_op = framework::OpRegistry::CreateOp( "sum", {{"X", {pg_names[param_id], new_inside_name}}}, - {{"Out", {pg_names[param_id]}}}, framework::AttributeMap{}); + {{"Out", {pg_names[param_id]}}}, + framework::AttributeMap{{"use_mkldnn", {false}}}); sum_op->Run(cur_scope, dev_place); cur_scope.Rename(new_inside_name, inside_grad_name); } diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index de711b7d23ef01d57a62087c552ea090f01f0386..2689d5e0787e0164bfb8e539399d8a378964e50a 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -99,5 +99,11 @@ inline mkldnn::memory::format GetMKLDNNFormat(const mkldnn::memory memory) { memory.get_primitive_desc().desc().data.format); } +inline mkldnn::memory::format GetMKLDNNFormat( + const mkldnn::sum::primitive_desc& memory) { + return static_cast( + memory.dst_primitive_desc().desc().data.format); +} + } // namespace platform } // namespace paddle diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index f7bbc98fe1f41c3a049692d36915d0a722779416..4faa06303170488d0de2fda4c1461cfe2d623d35 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -132,9 +132,9 @@ def _addup_repetitive_outputs_(op_descs): for idx, op_desc in enumerate(op_descs): for var_name in op_desc.input_arg_names(): if len(renamed_vars[var_name]) > 1: - pending_sum_ops.append( - (_create_op_desc_("sum", {"X": renamed_vars[var_name]}, - {"Out": [var_name]}, {}), idx)) + pending_sum_ops.append((_create_op_desc_( + "sum", {"X": renamed_vars[var_name]}, {"Out": [var_name]}, + {"use_mkldnn": False}), idx)) renamed_vars[var_name] = [var_name] for var_name in op_desc.output_arg_names(): if var_name == core.empty_var_name( @@ -161,8 +161,9 @@ def _addup_repetitive_outputs_(op_descs): renamed_vars[var_name].append(new_name) for var_name, inputs in renamed_vars.iteritems(): if len(inputs) > 1: - pending_sum_ops.append((_create_op_desc_( - "sum", {"X": inputs}, {"Out": [var_name]}, {}), len(op_descs))) + pending_sum_ops.append( + (_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]}, + {"use_mkldnn": False}), len(op_descs))) # sum_op descs are sorted according to their insert position for p in reversed(pending_sum_ops): op_descs.insert(p[1], p[0]) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 097fbe982f21fc77efe9fe61dfa27fe71c5f08f7..a87bead14c25f4dd1aa321fd705c59c87ed88490 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -All layers just related to the neural network. +All layers just related to the neural network. """ from ..layer_helper import LayerHelper @@ -109,14 +109,14 @@ def fc(input, """ **Fully Connected Layer** - This function creates a fully connected layer in the network. It can take - multiple tensors as its inputs. It creates a variable called weights for - each input tensor, which represents a fully connected weight matrix from - each input unit to each output unit. The fully connected layer multiplies - each input tensor with its coresponding weight to produce an output Tensor. - If multiple input tensors are given, the results of multiple multiplications - will be sumed up. If bias_attr is not None, a bias variable will be created - and added to the output. Finally, if activation is not None, it will be applied + This function creates a fully connected layer in the network. It can take + multiple tensors as its inputs. It creates a variable called weights for + each input tensor, which represents a fully connected weight matrix from + each input unit to each output unit. The fully connected layer multiplies + each input tensor with its coresponding weight to produce an output Tensor. + If multiple input tensors are given, the results of multiple multiplications + will be sumed up. If bias_attr is not None, a bias variable will be created + and added to the output. Finally, if activation is not None, it will be applied to the output as well. This process can be formulated as follows: @@ -198,7 +198,10 @@ def fc(input, else: pre_bias = helper.create_tmp_variable(dtype) helper.append_op( - type="sum", inputs={"X": mul_results}, outputs={"Out": pre_bias}) + type="sum", + inputs={"X": mul_results}, + outputs={"Out": pre_bias}, + attrs={"use_mkldnn": use_mkldnn}) # add bias pre_activation = helper.append_bias_op(pre_bias, dim_start=num_flatten_dims) # add activation @@ -847,7 +850,7 @@ def crf_decoding(input, param_attr, label=None): Returns: Variable: ${viterbi_path_comment} - + Examples: .. code-block:: python @@ -1085,7 +1088,7 @@ def chunk_eval(input, Here is a NER example of labeling for these tagging schemes: .. code-block:: python - + ====== ====== ====== ===== == ============ ===== ===== ===== == ========= Li Ming works at Agricultural Bank of China in Beijing. ====== ====== ====== ===== == ============ ===== ===== ===== == ========= @@ -1111,7 +1114,7 @@ def chunk_eval(input, is the num of chunk types, and `tag_type` get its value from the following table. .. code-block:: python - + Scheme Begin Inside End Single plain 0 - - - IOB 0 1 - - @@ -1147,7 +1150,7 @@ def chunk_eval(input, tuple: tuple containing: precision, recall, f1_score, num_infer_chunks, num_label_chunks, num_correct_chunks - + Examples: .. code-block:: python @@ -1247,7 +1250,7 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True): """ This function computes the softmax activation among all time-steps for each sequence. The dimension of each time-step should be 1. Thus, the shape of - input Tensor can be either :math:`[N, 1]` or :math:`[N]`, where :math:`N` + input Tensor can be either :math:`[N, 1]` or :math:`[N]`, where :math:`N` is the sum of the length of all sequences. For i-th sequence in a mini-batch: @@ -1267,7 +1270,7 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=True): param_attr (ParamAttr|None): attributes for parameter use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn \ library is installed. Default: True - + Returns: Variable: output of sequence_softmax @@ -1828,11 +1831,11 @@ def pool2d(input, ${comment} Args: - input (Variable): The input tensor of pooling operator. The format of - input tensor is NCHW, where N is batch size, C is - the number of channels, H is the height of the + input (Variable): The input tensor of pooling operator. The format of + input tensor is NCHW, where N is batch size, C is + the number of channels, H is the height of the feature, and W is the width of the feature. - pool_size (int): The side length of pooling windows. All pooling + pool_size (int): The side length of pooling windows. All pooling windows are squares with pool_size on a side. pool_type: ${pooling_type_comment} pool_stride (int): stride of the pooling layer. @@ -1841,7 +1844,7 @@ def pool2d(input, use_cudnn: ${use_cudnn_comment} ceil_mode: ${ceil_mode_comment} use_mkldnn: ${use_mkldnn_comment} - name (str|None): A name for this layer(optional). If set None, the + name (str|None): A name for this layer(optional). If set None, the layer will be named automatically. Returns: @@ -1859,10 +1862,10 @@ def pool2d(input, data = fluid.layers.data( name='data', shape=[3, 32, 32], dtype='float32') conv2d = fluid.layers.pool2d( - input=data, - pool_size=2, - pool_type='max', - pool_stride=1, + input=data, + pool_size=2, + pool_type='max', + pool_stride=1, global_pooling=False) """ if pool_type not in ["max", "avg"]: @@ -2227,14 +2230,14 @@ def beam_search_decode(ids, scores, name=None): This layers is to pack the output of beam search layer into sentences and associated scores. It is usually called after the beam search layer. Typically, the output of beam search layer is a tensor of selected ids, with - a tensor of the score of each id. Beam search layer's output ids, however, - are generated directly during the tree search, and they are stacked by each - level of the search tree. Thus we need to reorganize them into sentences, + a tensor of the score of each id. Beam search layer's output ids, however, + are generated directly during the tree search, and they are stacked by each + level of the search tree. Thus we need to reorganize them into sentences, based on the score of each id. This layer takes the output of beam search layer as input and repack them into sentences. Args: - ids (Variable): The selected ids, output of beam search layer. + ids (Variable): The selected ids, output of beam search layer. scores (Variable): The associated scores of the ids, out put of beam search layer. name (str): The name of this layer. It is optional. @@ -2242,7 +2245,7 @@ def beam_search_decode(ids, scores, name=None): Returns: tuple(Variable): a tuple of two output tensors: sentence_ids, sentence_scores. sentence_ids is a tensor with shape [size, length], where size is the - beam size of beam search, and length is the length of each sentence. + beam size of beam search, and length is the length of each sentence. Note that the length of sentences may vary. sentence_scores is a tensor with the same shape as sentence_ids. @@ -2919,7 +2922,7 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None): `None`, compute the mean over all elements of :attr:`input` and return a variable with a single element, otherwise it must be in the range :math:`[-rank(input), rank(input))`. If - :math:`dim[i] < 0`, the dimension to reduce is + :math:`dim[i] < 0`, the dimension to reduce is :math:`rank(input) + dim[i]`. keep_dim (bool): Whether to reserve the reduced dimension in the output Tensor. The result tensor will have one fewer dimension @@ -3390,16 +3393,16 @@ def topk(input, k, name=None): Args: input(Variable): The input variable which can be a vector or Tensor with higher rank. - k(int): The number of top elements to look for along the last dimension + k(int): The number of top elements to look for along the last dimension of input. name(str|None): A name for this layer(optional). If set None, the layer - will be named automatically. + will be named automatically. Default: None Returns: - Tuple[Variable]: A tuple with two elements. Each element is a Variable. - The first one is k largest elements along each last - dimensional slice. The second one is indices of values + Tuple[Variable]: A tuple with two elements. Each element is a Variable. + The first one is k largest elements along each last + dimensional slice. The second one is indices of values within the last dimension of input. Raises: @@ -3594,15 +3597,15 @@ def warpctc(input, label, blank=0, norm_by_times=False): It's shape is [Lp, num_classes + 1], where Lp is the sum of all input sequences' length and num_classes is the true number of classes. (not including the blank label). - label (Variable): The ground truth of variable-length sequence, + label (Variable): The ground truth of variable-length sequence, which is a 2-D Tensor with LoD information. It is of the shape [Lg, 1], where Lg is th sum of all labels' length. blank (int, default 0): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). - norm_by_times(bool, default false): Whether to normalize the gradients - by the number of time-step, which is also the sequence's length. - There is no need to normalize the gradients if warpctc layer was + norm_by_times(bool, default false): Whether to normalize the gradients + by the number of time-step, which is also the sequence's length. + There is no need to normalize the gradients if warpctc layer was follewed by a mean_op. Returns: @@ -3708,8 +3711,8 @@ def nce(input, input (Variable): input variable. label (Variable): label. num_total_classes (int):${num_total_classes_comment} - sample_weight (Variable|None): A Variable of shape [batch_size, 1] - storing a weight for each sample. The default weight for each + sample_weight (Variable|None): A Variable of shape [batch_size, 1] + storing a weight for each sample. The default weight for each sample is 1.0. param_attr (ParamAttr|None): attributes for parameter bias_attr (ParamAttr|None): attributes for bias @@ -4099,7 +4102,7 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): This layer computes the smooth L1 loss for Variable :attr:`x` and :attr:`y`. It takes the first dimension of :attr:`x` and :attr:`y` as batch size. For each instance, it computes the smooth L1 loss element by element first - and then sums all the losses. So the shape of ouput Variable is + and then sums all the losses. So the shape of ouput Variable is [batch_size, 1]. Args: @@ -4108,14 +4111,14 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): y (Variable): A tensor with rank at least 2. The target value of smooth L1 loss op with same shape as :attr:`x`. inside_weight (Variable|None): A tensor with rank at least 2. This - input is optional and should have same shape with :attr:`x`. If - provided, the result of (:attr:`x` - :attr:`y`) will be multiplied + input is optional and should have same shape with :attr:`x`. If + provided, the result of (:attr:`x` - :attr:`y`) will be multiplied by this tensor element by element. outside_weight (Variable|None): A tensor with rank at least 2. This - input is optional and should have same shape with :attr:`x`. If - provided, the out smooth L1 loss will be multiplied by this tensor + input is optional and should have same shape with :attr:`x`. If + provided, the out smooth L1 loss will be multiplied by this tensor element by element. - sigma (float|None): Hyper parameter of smooth L1 loss layer. A float + sigma (float|None): Hyper parameter of smooth L1 loss layer. A float scalar with default value 1.0. Returns: @@ -4161,7 +4164,7 @@ def one_hot(input, depth): Examples: .. code-block:: python - + label = layers.data(name="label", shape=[1], dtype="float32") one_hot_label = layers.one_hot(input=label, depth=10) """ @@ -4315,10 +4318,10 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): def lod_reset(x, y=None, target_lod=None): """ Set LoD of :attr:`x` to a new one specified by :attr:`y` or - :attr:`target_lod`. When :attr:`y` provided, :attr:`y.lod` would be - considered as target LoD first, otherwise :attr:`y.data` would be - considered as target LoD. If :attr:`y` is not provided, target LoD should - be specified by :attr:`target_lod`. If target LoD is specified by + :attr:`target_lod`. When :attr:`y` provided, :attr:`y.lod` would be + considered as target LoD first, otherwise :attr:`y.data` would be + considered as target LoD. If :attr:`y` is not provided, target LoD should + be specified by :attr:`target_lod`. If target LoD is specified by :attr:`Y.data` or :attr:`target_lod`, only one level LoD is supported. .. code-block:: text @@ -4372,7 +4375,7 @@ def lod_reset(x, y=None, target_lod=None): Args: x (Variable): Input variable which could be a Tensor or LodTensor. - y (Variable|None): If provided, output's LoD would be derived + y (Variable|None): If provided, output's LoD would be derived from :attr:`y`. target_lod (list|tuple|None): One level LoD which should be considered as target LoD when :attr:`y` not provided. @@ -4688,7 +4691,7 @@ def image_resize(input, """ **Resize a Batch of Images** - The input must be a tensor of the shape (num_batches, channels, in_h, in_w), + The input must be a tensor of the shape (num_batches, channels, in_h, in_w), and the resizing only applies on the last two dimensions(hight and width). Supporting resample methods: @@ -4784,9 +4787,9 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None): def image_resize_short(input, out_short_len, resample='BILINEAR'): """ - Resize a batch of images. The short edge of input images will be - resized to the given 'out_short_len'. The long edge of input images - will be resized proportionately to make images' length-width ratio + Resize a batch of images. The short edge of input images will be + resized to the given 'out_short_len'. The long edge of input images + will be resized proportionately to make images' length-width ratio constant. Args: @@ -4819,7 +4822,7 @@ def gather(input, index): """ **Gather Layer** - Output is obtained by gathering entries of the outer-most dimension + Output is obtained by gathering entries of the outer-most dimension of X indexed by `index` and concatenate them together. .. math:: @@ -4844,7 +4847,7 @@ def gather(input, index): [5, 6]] Args: - input (Variable): The source input with rank>=1. + input (Variable): The source input with rank>=1. index (Variable): The index input with rank=1. Returns: @@ -4880,7 +4883,7 @@ def random_crop(x, shape, seed=None): Returns: ${out_comment} - + Examples: >>> img = fluid.layers.data("img", [3, 256, 256]) >>> cropped_img = fluid.layers.random_crop(img, shape=[3, 224, 224]) @@ -4926,7 +4929,7 @@ def log(x): Out = \\ln(x) Args: - x (Variable): Input tensor. + x (Variable): Input tensor. Returns: Variable: The natural log of the input tensor computed element-wise. @@ -4955,7 +4958,7 @@ def relu(x): Out = \\max(0, x) Args: - x (Variable): The input tensor. + x (Variable): The input tensor. Returns: Variable: The output tensor with the same shape as input. @@ -4976,15 +4979,15 @@ def relu(x): def mean_iou(input, label, num_classes): """ Mean Intersection-Over-Union is a common evaluation metric for - semantic image segmentation, which first computes the IOU for each - semantic class and then computes the average over classes. - IOU is defined as follows: - + semantic image segmentation, which first computes the IOU for each + semantic class and then computes the average over classes. + IOU is defined as follows: + .. math:: IOU = \\frac{true\_positiv}{(true\_positive + false\_positive + false\_negative)}. - The predictions are accumulated in a confusion matrix and mean-IOU + The predictions are accumulated in a confusion matrix and mean-IOU is then calculated from it. @@ -4997,12 +5000,12 @@ def mean_iou(input, label, num_classes): Returns: mean_iou (Variable): A Tensor representing the mean intersection-over-union with shape [1]. out_wrong(Variable): A Tensor with shape [num_classes]. The wrong numbers of each class. - out_correct(Variable): A Tensor with shape [num_classes]. The correct numbers of each class. + out_correct(Variable): A Tensor with shape [num_classes]. The correct numbers of each class. Examples: .. code-block:: python - + iou, wrongs, corrects = fluid.layers.mean_iou(predict, label, num_classes) """ helper = LayerHelper('mean_iou', **locals()) diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 149e77b52415025e78fbf4ee641151880b415fb0..b7a8bff30d3baffb7ec4d67a9bf6f5b00e3aa983 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -230,7 +230,11 @@ def sums(input, out=None): helper = LayerHelper('sum', **locals()) if out is None: out = helper.create_tmp_variable(dtype=helper.input_dtype()) - helper.append_op(type='sum', inputs={'X': input}, outputs={'Out': out}) + helper.append_op( + type='sum', + inputs={'X': input}, + outputs={'Out': out}, + attrs={'use_mkldnn': False}) return out @@ -380,7 +384,7 @@ def argmin(x, axis=0): """ **argmin** - This function computes the indices of the min elements + This function computes the indices of the min elements of the input tensor's element along the provided axis. Args: @@ -395,7 +399,7 @@ def argmin(x, axis=0): .. code-block:: python out = fluid.layers.argmin(x=in, axis=0) - out = fluid.layers.argmin(x=in, axis=-1) + out = fluid.layers.argmin(x=in, axis=-1) """ helper = LayerHelper("arg_min", **locals()) out = helper.create_tmp_variable(VarDesc.VarType.INT64) @@ -411,7 +415,7 @@ def argmax(x, axis=0): """ **argmax** - This function computes the indices of the max elements + This function computes the indices of the max elements of the input tensor's element along the provided axis. Args: @@ -426,7 +430,7 @@ def argmax(x, axis=0): .. code-block:: python out = fluid.layers.argmax(x=in, axis=0) - out = fluid.layers.argmax(x=in, axis=-1) + out = fluid.layers.argmax(x=in, axis=-1) """ helper = LayerHelper("arg_max", **locals()) out = helper.create_tmp_variable(VarDesc.VarType.INT64) @@ -495,9 +499,9 @@ def reverse(x, axis): Args: x(Vairbale): the input to be reversed. - axis(int|tuple|list): Axis that along which order of elements - is reversed. If it is a tuple or a list, reversing - will be apply on each axis in the tuple or list. + axis(int|tuple|list): Axis that along which order of elements + is reversed. If it is a tuple or a list, reversing + will be apply on each axis in the tuple or list. Returns: Variable: The reversed tensor. @@ -528,9 +532,9 @@ def save(x, file_path, overwrite=True): Args: x(variable): The Tensor/LoDTensor to be saved. file_path(str): The file path where the variable will be saved. - overwrite(bool): Whether or not cover the given file when it has already - existed. If it's set 'False' and the file is existed, a runtime - error will be thrown. + overwrite(bool): Whether or not cover the given file when it has already + existed. If it's set 'False' and the file is existed, a runtime + error will be thrown. """ helper = LayerHelper("save", **locals()) helper.append_op( @@ -550,8 +554,8 @@ def save_combine(x, file_path, overwrite=True): a single file. file_path(str): The file path where variables will be saved. overwrite(bool): Whether or not cover the given file when it has already - existed. If it's set 'False' and the file is existed, a runtime - error will be thrown. + existed. If it's set 'False' and the file is existed, a runtime + error will be thrown. Returns: There is no return value. diff --git a/python/paddle/fluid/tests/unittests/test_sum_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_sum_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..7956897d68a3fb49d62ba696d0b6400b4f909989 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_sum_mkldnn_op.py @@ -0,0 +1,26 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from test_sum_op import TestSumOp + + +class TestMKLDNN(TestSumOp): + def init_kernel_type(self): + self.use_mkldnn = True + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index 2faf5b10647a1fa1d44e4847f017db177ee8808a..1d90414e137a70e6265042e24e106fe565802778 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -20,12 +20,15 @@ from op_test import OpTest class TestSumOp(OpTest): def setUp(self): self.op_type = "sum" + self.use_mkldnn = False + self.init_kernel_type() x0 = np.random.random((3, 4)).astype('float32') x1 = np.random.random((3, 4)).astype('float32') x2 = np.random.random((3, 4)).astype('float32') self.inputs = {"X": [("x0", x0), ("x1", x1), ("x2", x2)]} y = x0 + x1 + x2 self.outputs = {'Out': y} + self.attrs = {'use_mkldnn': self.use_mkldnn} def test_check_output(self): self.check_output() @@ -33,6 +36,9 @@ class TestSumOp(OpTest): def test_check_grad(self): self.check_grad(['x0'], 'Out') + def init_kernel_type(self): + pass + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 041f0aa42ffe88b2086ccce1b6e06d33dae2e8af..d8d6a7e9418e1c2a9f82d58b5c9650d58604d46e 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -872,7 +872,8 @@ class DistributeTranspiler(object): table_opt_block.append_op( type="sum", inputs={"X": pserver_side_table_grad_list}, - outputs={"Out": [grad_var]}) + outputs={"Out": [grad_var]}, + attrs={"use_mkldnn": False}) else: # in async_mode, for table gradient, it also need to be splited to each parameter server origin_grad_name = grad_var.name @@ -1104,7 +1105,8 @@ class DistributeTranspiler(object): optimize_block.append_op( type="sum", inputs={"X": vars2merge}, - outputs={"Out": merged_var}) + outputs={"Out": merged_var}, + attrs={"use_mkldnn": False}) # TODO(panyx0718): What if it's SELECTED_ROWS. if not merged_var.type == core.VarDesc.VarType.SELECTED_ROWS: optimize_block.append_op( diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py index 44a6e344630bb35d28ee29078bf8727053a24bef..1f83cabb8481451736944823be45185deea4f43b 100644 --- a/python/paddle/reader/decorator.py +++ b/python/paddle/reader/decorator.py @@ -336,7 +336,7 @@ def _buf2lines(buf, line_break="\n"): class PipeReader: """ - PipeReader read data by stream from a command, take it's + PipeReader read data by stream from a command, take it's stdout into a pipe buffer and redirect it to the parser to parse, then yield data as your desired format. @@ -352,7 +352,7 @@ class PipeReader: An example: .. code-block:: python - + def example_reader(): for f in myfiles: pr = PipeReader("cat %s"%f)