diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index bc08c0730be4bf7abf6b022a47a793639368fcb1..5824fa380470626a29c3e3afb36362aba63578fe 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -194,6 +194,8 @@ paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=Non paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None)) paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None)) +paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 832245371e0b1966000ec0252a58ca02193332a7..9c5b8604f40ae56c463b54c71623feb61bd8d297 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -76,8 +76,8 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, } #endif return framework::OpKernelType( - framework::ToDataType(ctx.Input(name)->type()), - ctx.GetPlace(), layout, library); + framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout, + library); } class ActivationOp : public framework::OperatorWithKernel { diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index a0f8c5c14c48cb1e2be60b53a2198e30b050b33d..87d549678a0e6c183aac89539cf1f6331729de2c 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -41,6 +41,12 @@ static std::unordered_set InplaceOpSet = { "floor", "reciprocal", "relu6", "soft_relu", "hard_sigmoid", }; +/* The following operator can be used to process SelectedRows, because the + * output of those operator for zero is zero too. + */ +static std::unordered_set CanBeUsedBySelectedRows = { + "abs", "abs_grad", "square", "square_grad", "sqrt", "sqrt_grad"}; + static bool IsInplace(std::string op) { return InplaceOpSet.count(op); } template @@ -50,16 +56,38 @@ class ActivationKernel using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { - auto& X = detail::Ref(context.Input("X"), - "Cannot get input tensor X, variable name = %s", - context.op().Input("X")); - - auto& Out = detail::Ref(context.Output("Out"), - "Cannot get output tensor Out, variable name = %s", - context.op().Output("Out")); - Out.mutable_data(context.GetPlace()); + auto x_var = context.InputVar("X"); + auto out_var = context.OutputVar("Out"); + PADDLE_ENFORCE(x_var != nullptr, + "Cannot get input Variable X, variable name = %s", + context.op().Input("X")); + PADDLE_ENFORCE(out_var != nullptr, + "Cannot get output Variable Out, variable name = %s", + context.op().Output("Out")); + + framework::Tensor X, *Out; + + if (CanBeUsedBySelectedRows.count(context.op().Type())) { + X = detail::Ref( + paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var), + "Cannot get input Tensor X, variable name = %s", + context.op().Input("X")); + Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( + out_var); + } else { + X = detail::Ref(context.Input("X"), + "Cannot get input Tensor X, variable name = %s", + context.op().Input("X")); + Out = context.Output("Out"); + } + + PADDLE_ENFORCE(Out != nullptr, + "Cannot get output tensor Out, variable name = %s", + context.op().Output("Out")); + + Out->mutable_data(context.GetPlace()); auto x = framework::EigenVector::Flatten(X); - auto out = framework::EigenVector::Flatten(Out); + auto out = framework::EigenVector::Flatten(*Out); auto* place = context.template device_context().eigen_device(); Functor functor; @@ -78,14 +106,54 @@ class ActivationGradKernel public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { - auto* Out = context.Input("Out"); - auto* dOut = - context.Input(framework::GradVarName("Out")); - auto* dX = context.Output(framework::GradVarName("X")); + auto out_var = context.InputVar("Out"); + auto out_grad_var = context.InputVar(framework::GradVarName("Out")); + auto x_grad_var = context.OutputVar(framework::GradVarName("X")); + PADDLE_ENFORCE(out_var != nullptr, + "Cannot get input Variable Out, variable name = %s", + context.op().Input("Out")); + PADDLE_ENFORCE(out_grad_var != nullptr, + "Cannot get input Variable %s, variable name = %s", + framework::GradVarName("Out"), + context.op().Input(framework::GradVarName("Out"))); + PADDLE_ENFORCE(x_grad_var != nullptr, + "Cannot get output Variable %s, variable name = %s", + framework::GradVarName("X"), + context.op().Output(framework::GradVarName("X"))); + + framework::Tensor Out, dOut, *dX; + if (CanBeUsedBySelectedRows.count(context.op().Type())) { + Out = detail::Ref( + paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var), + "Cannot get input Tensor Out, variable name = %s", + context.op().Input("Out")); + dOut = + detail::Ref(paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar( + *out_grad_var), + "Cannot get input Tensor %s, variable name = %s", + framework::GradVarName("Out"), + context.op().Input(framework::GradVarName("Out"))); + dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( + x_grad_var); + } else { + Out = detail::Ref(context.Input("Out"), + "Cannot get input Tensor Out, variable name = %s", + context.op().Input("Out")); + dOut = detail::Ref( + context.Input(framework::GradVarName("Out")), + "Cannot get input Tensor %s, variable name = %s", + framework::GradVarName("Out"), + context.op().Input(framework::GradVarName("Out"))); + dX = context.Output(framework::GradVarName("X")); + } + PADDLE_ENFORCE(dX != nullptr, + "Cannot get output tensor %s, variable name = %s", + framework::GradVarName("X"), + context.op().Output(framework::GradVarName("X"))); dX->mutable_data(context.GetPlace()); - auto dout = framework::EigenVector::Flatten(*dOut); - auto out = framework::EigenVector::Flatten(*Out); + auto dout = framework::EigenVector::Flatten(dOut); + auto out = framework::EigenVector::Flatten(Out); auto dx = framework::EigenVector::Flatten(*dX); auto* place = context.template device_context().eigen_device(); @@ -96,8 +164,19 @@ class ActivationGradKernel } bool inplace = functor.Inplace(); if (!inplace) { - auto* X = context.Input("X"); - auto x = framework::EigenVector::Flatten(*X); + auto x_var = context.InputVar("X"); + PADDLE_ENFORCE(x_var != nullptr, + "Cannot get input tensor X, variable name = %s", + context.op().Input("X")); + framework::Tensor X; + if (CanBeUsedBySelectedRows.count(context.op().Type())) { + X = detail::Ref( + paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var)); + } else { + X = detail::Ref(context.Input("X")); + } + + auto x = framework::EigenVector::Flatten(X); functor(*place, x, out, dout, dx); } else { VLOG(10) << " Inplace activation "; diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index dc25bc57103286ce183a4649964fd96c62169b7f..a8b8a67a114b956f2d6b1b072ef343a179114b34 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -60,15 +60,37 @@ template class ElementwiseMulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* x = ctx.Input("X"); + auto x_var = ctx.InputVar("X"); + PADDLE_ENFORCE(x_var != nullptr, + "Cannot get input Variable X, variable name = %s", + ctx.op().Input("X")); auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); + + framework::Tensor x, *z; + if (x_var->IsType()) { + PADDLE_ENFORCE(y->dims().size() == 1 && y->dims()[0] == 1, + "For elementwise_op, if X is Sparse, Y must be scalar."); + auto& x_sele = x_var->Get(); + auto out_sele = ctx.Output("Out"); + x = x_sele.value(); + out_sele->set_rows(x_sele.rows()); + out_sele->set_height(x_sele.height()); + out_sele->mutable_value()->Resize(x_sele.value().dims()); + out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type()); + z = ctx.Output("Out")->mutable_value(); + } else if (x_var->IsType()) { + x = x_var->Get(); + z = ctx.Output("Out"); + } else { + PADDLE_THROW("X's type[%s] is not supported by elementwise_op.", + x_var->Type().name()); + } z->mutable_data(ctx.GetPlace()); - if (x->numel() == y->numel()) { - elementwise_mul(ctx, x, y, z); + if (x.numel() == y->numel()) { + elementwise_mul(ctx, &x, y, z); } else { - default_elementwise_mul(ctx, x, y, z); + default_elementwise_mul(ctx, &x, y, z); } } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index 85a7817be9b3a82d40853b417d78a7fdf67f6c1f..87bf7c6b156f32b8f6a1abc30b0676e1d4711d64 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -40,21 +40,28 @@ class ElementwiseOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of elementwise op should not be null."); - PADDLE_ENFORCE( - ctx->GetInputsVarType("X").front() == - framework::proto::VarType::LOD_TENSOR, - "The input var's type should be LoDTensor, but the received is %s", - ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front()); PADDLE_ENFORCE( ctx->GetInputsVarType("Y").front() == framework::proto::VarType::LOD_TENSOR, - "The input var's type should be LoDTensor, but the received is %s", - ctx->Inputs("Y").front(), ctx->GetInputsVarType("Y").front()); - - auto x_dim = ctx->GetInputDim("X"); - auto y_dim = ctx->GetInputDim("Y"); - PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), - "Rank of first input must >= rank of second input."); + "The input var's type should be LoDTensor, but the received is %s [%s]", + ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front()); + + if (ctx->GetInputsVarType("X").front() == + framework::proto::VarType::LOD_TENSOR) { + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), + "Rank of first input must >= rank of second input."); + } else if (ctx->GetInputsVarType("X").front() == + framework::proto::VarType::SELECTED_ROWS) { + PADDLE_ENFORCE((ctx->GetInputDim("Y").size() == 1u) && + (ctx->GetInputDim("Y")[0] == 1), + "For elementwise_op, if X is Sparse, " + "Y must be scalar."); + } else { + PADDLE_THROW("X's type[%s] is not supported by elementwise_op.", + ctx->GetInputsVarType("X").front()); + } ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out"); diff --git a/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4ae19d9c1e3bb2af3eb95650fbb5aabb8944a36 --- /dev/null +++ b/paddle/fluid/operators/get_tensor_from_selected_rows_op.cc @@ -0,0 +1,117 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/tensor_util.h" + +namespace paddle { +namespace operators { + +class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "GetTensorFromSelectedRowsOp must has input X."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "GetTensorFromSelectedRowsOp must has output Out."); + PADDLE_ENFORCE( + ctx->GetInputsVarType("X").front() == + framework::proto::VarType::SELECTED_ROWS, + "The input X's type should be SelectedRows, but the received is %s", + ctx->Inputs("X").front(), ctx->GetInputsVarType("X").front()); + PADDLE_ENFORCE( + ctx->GetOutputsVarType("Out").front() == + framework::proto::VarType::LOD_TENSOR, + "The output Out's type should be LoDTensor, but the received is %s", + ctx->Outputs("Out").front(), ctx->GetOutputsVarType("Out").front()); + + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.device_context()); + } +}; + +class GetTensorFromSelectedRowsKernel { + public: + void operator()(const framework::ExecutionContext &ctx) const { + auto *x = ctx.Input("X"); + auto *out = ctx.Output("Out"); + + out->Resize(x->value().dims()); + out->mutable_data(ctx.GetPlace(), x->value().type()); + framework::TensorCopy(x->value(), ctx.GetPlace(), ctx.device_context(), + out); + } +}; + +class GetTensorFromSelectedRowsOpProtoMaker + : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input type is SelectedRows."); + AddOutput("Out", "The output type is LoDTensor."); + AddComment( + R"DOC( +GetTensorFromSelectedRows Operator + +GetTensorFromSelectedRows is used to get the tensor from SelectedRows. + +)DOC"); + } +}; + +class GetTensorFromSelectedRowsOpVarTypeInference + : public framework::VarTypeInference { + public: + void operator()(const framework::OpDesc &op_desc, + framework::BlockDesc *block) const final { + auto out_var_name = op_desc.Output("Out").front(); + auto in_var_name = op_desc.Input("X").front(); + + auto out_var = block->FindRecursiveOrCreateVar(out_var_name); + auto in_var = block->FindRecursiveOrCreateVar(in_var_name); + out_var.SetType(framework::proto::VarType::LOD_TENSOR); + out_var.SetDataType(in_var.GetDataType()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(get_tensor_from_selected_rows, + ops::GetTensorFromSelectedRowsOp, + ops::GetTensorFromSelectedRowsOpProtoMaker, + ops::GetTensorFromSelectedRowsOpVarTypeInference); + +REGISTER_OP_CPU_KERNEL_FUNCTOR(get_tensor_from_selected_rows, float, + ops::GetTensorFromSelectedRowsKernel, double, + ops::GetTensorFromSelectedRowsKernel, int, + ops::GetTensorFromSelectedRowsKernel, int64_t, + ops::GetTensorFromSelectedRowsKernel); + +#ifdef PADDLE_WITH_CUDA +REGISTER_OP_CUDA_KERNEL_FUNCTOR(get_tensor_from_selected_rows, float, + ops::GetTensorFromSelectedRowsKernel, double, + ops::GetTensorFromSelectedRowsKernel, int, + ops::GetTensorFromSelectedRowsKernel, int64_t, + ops::GetTensorFromSelectedRowsKernel); +#endif diff --git a/paddle/fluid/operators/merge_selected_rows_op.cc b/paddle/fluid/operators/merge_selected_rows_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..c0d10c15997014fab864ebb916e6dafd96073e52 --- /dev/null +++ b/paddle/fluid/operators/merge_selected_rows_op.cc @@ -0,0 +1,93 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/merge_selected_rows_op.h" + +namespace paddle { +namespace operators { + +class MergeSelectedRowsOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of MergeSelectedRowsOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of MergeSelectedRowsOp should not be null."); + ctx->ShareDim("X", /*->*/ "Out"); + } +}; + +class MergeSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input type is SelectedRows, and the selected rows may be " + "duplicated."); + AddOutput("Out", + "The output type is SelectedRows, and the selected rows are not " + "duplicated."); + AddComment( + R"DOC( +MergeSelectedRows Operator. + +MergeSelectedRows is used to merge the duplicated rows of the input. The +output's row has no duplicated, and it is incremental in order. + +Example: + Input: + X.rows is [0, 5, 5, 4, 19] + X.height is 20 + X.value is: + [[1, 1] + [2, 2] + [3, 3] + [4, 4] + [6, 6]] + + Output: + Out.row is [0, 4, 5, 19] + Out.height is 20 + Out.value is: + [[1, 1] + [4, 4] + [5, 5] + [6, 6]] +)DOC"); + } +}; + +class MergeSelectedRowsOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map GetInputOutputWithSameType() + const override { + return std::unordered_map{{"X", /*->*/ "Out"}}; + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OPERATOR(merge_selected_rows, ops::MergeSelectedRowsOp, + ops::MergeSelectedRowsOpMaker, + ops::MergeSelectedRowsOpInferVarType); + +REGISTER_OP_CPU_KERNEL( + merge_selected_rows, + ops::MergeSelectedRowsKernel, + ops::MergeSelectedRowsKernel); diff --git a/paddle/fluid/operators/merge_selected_rows_op.cu.cc b/paddle/fluid/operators/merge_selected_rows_op.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..90d5fb3eaeb1f155eeea29ea0cf3f5ecd610f5f0 --- /dev/null +++ b/paddle/fluid/operators/merge_selected_rows_op.cu.cc @@ -0,0 +1,23 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/merge_selected_rows_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + merge_selected_rows, + ops::MergeSelectedRowsKernel, + ops::MergeSelectedRowsKernel); diff --git a/paddle/fluid/operators/merge_selected_rows_op.h b/paddle/fluid/operators/merge_selected_rows_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4c977e94b175c988e4253b273365b0cabc4b87aa --- /dev/null +++ b/paddle/fluid/operators/merge_selected_rows_op.h @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" + +namespace paddle { +namespace operators { + +template +class MergeSelectedRowsKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + + math::scatter::MergeAdd merge_func; + merge_func(context.template device_context(), *x, out); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 1738afe93e99f1de28bec2fb23be8e1a309d9288..4315c4fc49d331bdc1b1ae74d13e202b17bf9995 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -271,7 +271,12 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): "All parameters' 'clip_norm' of a same group should be the same" ) - square = grad * grad + merge_grad = grad + if grad.type == core.VarDesc.VarType.SELECTED_ROWS: + merge_grad = layers.merge_selected_rows(grad) + merge_grad = layers.get_tensor_from_selected_rows(merge_grad) + + square = layers.square(merge_grad) local_norm_var = layers.reduce_sum(input=square) context[self.group_name].append(local_norm_var) @@ -292,6 +297,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): new_grad = layers.elementwise_mul( x=grad, y=self.context[group_scale_name]) + return param, new_grad diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 831a84b140ca733531dda7859fefe2ac9b35c2ce..d06769edbc7c1edf4d521a4aec03dddc35f3ca65 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -169,6 +169,8 @@ __all__ = [ 'log_loss', 'add_position_encoding', 'bilinear_tensor_product', + 'merge_selected_rows', + 'get_tensor_from_selected_rows', 'lstm', ] @@ -8291,6 +8293,29 @@ def mean(x, name=None): return out +@templatedoc() +def merge_selected_rows(x, name=None): + """ + ${comment} + + Args: + x(${x_type}): ${x_comment} + name(basestring|None): Name of the output. + + Returns: + out(${out_type}): ${out_comment} + """ + + helper = LayerHelper("merge_selected_rows", **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type="merge_selected_rows", + inputs={"X": x}, + attrs={}, + outputs={"Out": out}) + return out + + @templatedoc() def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): """ @@ -8939,3 +8964,26 @@ def bilinear_tensor_product(x, # add activation return helper.append_activation(out) + + +@templatedoc() +def get_tensor_from_selected_rows(x, name=None): + """ + ${comment} + + Args: + x(${x_type}): ${x_comment} + name(basestring|None): Name of the output. + + Returns: + out(${out_type}): ${out_comment} + """ + + helper = LayerHelper('get_tensor_from_selected_rows', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='get_tensor_from_selected_rows', + inputs={'X': x}, + outputs={'Out': out}, + attrs={}) + return out diff --git a/python/paddle/fluid/tests/test_gradient_clip.py b/python/paddle/fluid/tests/test_gradient_clip.py deleted file mode 100644 index 266687fcd092dfdeec9343e2592f4c22b683d588..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/test_gradient_clip.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import numpy as np -import paddle -import paddle.fluid as fluid - -BATCH_SIZE = 128 -CLIP = 1 - -prog = fluid.framework.Program() -with fluid.program_guard(main_program=prog): - image = fluid.layers.data(name='x', shape=[784], dtype='float32') - - hidden1 = fluid.layers.fc(input=image, size=128, act='relu') - hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu') - predict = fluid.layers.fc(input=hidden2, size=10, act='softmax') - - label = fluid.layers.data(name='y', shape=[1], dtype='int64') - - cost = fluid.layers.cross_entropy(input=predict, label=label) - avg_cost = fluid.layers.mean(cost) - -prog_clip = prog.clone() - -avg_cost_clip = prog_clip.block(0).var(avg_cost.name) - -p_g = fluid.backward.append_backward(loss=avg_cost) -p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip) - -with fluid.program_guard(main_program=prog_clip): - fluid.clip.set_gradient_clip( - fluid.clip.GradientClipByGlobalNorm(clip_norm=CLIP)) - p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip) - -grad_list = [elem[1] for elem in p_g] -grad_clip_list = [elem[1] for elem in p_g_clip] - -train_reader = paddle.batch( - paddle.reader.shuffle( - paddle.dataset.mnist.train(), buf_size=8192), - batch_size=BATCH_SIZE) - -place = fluid.CPUPlace() -exe = fluid.Executor(place) -feeder = fluid.DataFeeder(feed_list=[image, label], place=place) -exe.run(fluid.default_startup_program()) - -count = 0 -for data in train_reader(): - count += 1 - if count > 5: - break - out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list) - out_clip = exe.run(prog_clip, - feed=feeder.feed(data), - fetch_list=grad_clip_list) - global_norm = 0 - for v in out[1:]: - global_norm += np.sum(np.power(v, 2)) - global_norm = np.sqrt(global_norm) - - global_norm_clip = 0 - for v in out_clip[1:]: - global_norm_clip += np.sum(np.power(v, 2)) - global_norm_clip = np.sqrt(global_norm_clip) - - if not np.isclose( - a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3): - exit(1) -exit(0) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 26035f303e72a87b81fdb120fbb92894d78e996b..65f80d368cdaeead742a86dd6bfb30cb8a6a9498 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -43,13 +43,14 @@ if(APPLE) list(REMOVE_ITEM TEST_OPS test_desc_clone) list(REMOVE_ITEM TEST_OPS test_program_code) endif(NOT WITH_DISTRIBUTE) - message(WARNING "These tests has been disabled in OSX before being fixed: \n test_fuse_elewise_add_act_pass \n test_detection_map_op \n test_dist_se_resnext") + message(WARNING "These tests has been disabled in OSX before being fixed: \n test_gradient_clip \n test_fuse_elewise_add_act_pass \n test_detection_map_op \n test_dist_se_resnext") # this op is not support on mac list(REMOVE_ITEM TEST_OPS test_fusion_seqexpand_concat_fc_op) # TODO: add the unitest back when it fixed list(REMOVE_ITEM TEST_OPS test_detection_map_op) list(REMOVE_ITEM TEST_OPS test_dist_se_resnext) list(REMOVE_ITEM TEST_OPS test_fuse_elewise_add_act_pass) + list(REMOVE_ITEM TEST_OPS test_gradient_clip) endif() if(NOT WITH_MKLML) # this op is not support on openblas diff --git a/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py b/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py new file mode 100644 index 0000000000000000000000000000000000000000..021b950b3b6245caecab22d476bbb9d6b6b45c5e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py @@ -0,0 +1,65 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid.core as core +import numpy as np +from paddle.fluid.op import Operator + + +class TestGetTensorFromSelectedRows(unittest.TestCase): + def get_places(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + return places + + def check_with_place(self, place): + scope = core.Scope() + x_rows = [0, 5, 5, 4, 20] + height = 20 + row_numel = 2 + + np_array = np.ones((len(x_rows), row_numel)).astype("float32") + np_array[1, :] = 2.0 + np_array[2, :] = 3.0 + np_array[3, :] = 4.0 + + # initialize input variable X + x = scope.var('X').get_selected_rows() + x.set_rows(x_rows) + x.set_height(height) + x_tensor = x.get_tensor() + x_tensor.set(np_array, place) + + # initialize input variable Out + out = scope.var("Out").get_tensor() + + op = Operator("get_tensor_from_selected_rows", X="X", Out="Out") + + op.run(scope, place) + + out_array = np.array(out) + self.assertEqual((5, 2), out_array.shape) + assert (out_array == np_array).all() + + def test_check_output(self): + for place in self.get_places(): + self.check_with_place(place) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b3168ba6636253055f546fb3eec8a536714209 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py @@ -0,0 +1,162 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid + +BATCH_SIZE = 128 +CLIP = 1 + + +def bow_net(data, + label, + dict_dim, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2): + """ + BOW net + This model is from https://github.com/PaddlePaddle/models: + fluid/PaddleNLP/text_classification/nets.py + """ + emb = fluid.layers.embedding( + input=data, is_sparse=True, size=[dict_dim, emb_dim]) + bow = fluid.layers.sequence_pool(input=emb, pool_type='sum') + bow_tanh = fluid.layers.tanh(bow) + fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh") + fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh") + prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax") + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + + return avg_cost + + +class TestGradientClip(unittest.TestCase): + def setUp(self): + self.word_dict = paddle.dataset.imdb.word_dict() + self.BATCH_SIZE = 2 + self.train_data = paddle.batch( + paddle.dataset.imdb.train(self.word_dict), + batch_size=self.BATCH_SIZE) + + def get_places(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + return places + + def check_operators(self, place): + prog = fluid.framework.Program() + startup_program = fluid.framework.Program() + with fluid.program_guard( + main_program=prog, startup_program=startup_program): + image = fluid.layers.data(name='x', shape=[784], dtype='float32') + label = fluid.layers.data(name='y', shape=[1], dtype='int64') + + hidden1 = fluid.layers.fc(input=image, size=128, act='relu') + hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu') + predict = fluid.layers.fc(input=hidden2, size=10, act='softmax') + + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(cost) + + prog_clip = prog.clone() + + avg_cost_clip = prog_clip.block(0).var(avg_cost.name) + + p_g = fluid.backward.append_backward(loss=avg_cost) + p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip) + + with fluid.program_guard(main_program=prog_clip): + fluid.clip.set_gradient_clip( + fluid.clip.GradientClipByGlobalNorm(clip_norm=CLIP)) + p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip) + + grad_list = [elem[1] for elem in p_g] + grad_clip_list = [elem[1] for elem in p_g_clip] + + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=8192), + batch_size=BATCH_SIZE) + + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=[image, label], place=place) + exe.run(startup_program) + + count = 0 + for data in train_reader(): + count += 1 + if count > 5: + break + out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list) + out_clip = exe.run(prog_clip, + feed=feeder.feed(data), + fetch_list=grad_clip_list) + global_norm = 0 + for v in out[1:]: + global_norm += np.sum(np.power(v, 2)) + global_norm = np.sqrt(global_norm) + + global_norm_clip = 0 + for v in out_clip[1:]: + global_norm_clip += np.sum(np.power(v, 2)) + global_norm_clip = np.sqrt(global_norm_clip) + + assert np.isclose( + a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3) + + def check_sparse_gradient_clip(self, place): + prog = fluid.framework.Program() + startup_program = fluid.framework.Program() + with fluid.program_guard( + main_program=prog, startup_program=startup_program): + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + cost = bow_net(data, label, len(self.word_dict)) + + fluid.clip.set_gradient_clip( + clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)) + + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01) + sgd_optimizer.minimize(cost) + + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=[data, label], place=place) + exe.run(startup_program) + + data = next(self.train_data()) + val = exe.run(prog, feed=feeder.feed(data), fetch_list=[cost])[0] + self.assertEqual((1, ), val.shape) + print(val) + self.assertFalse(np.isnan(val)) + + def test_operators(self): + self.check_operators(core.CPUPlace()) + + def test_sparse_gradient_clip(self): + for place in self.get_places(): + self.check_sparse_gradient_clip(place) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py b/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py new file mode 100644 index 0000000000000000000000000000000000000000..d2fa344b67ab33a93f92733efd68e896c767bad2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py @@ -0,0 +1,73 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle.fluid.core as core +import numpy as np +from paddle.fluid.op import Operator + + +class TestMergeSelectedRows(unittest.TestCase): + def get_places(self): + places = [core.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(core.CUDAPlace(0)) + return places + + def check_with_place(self, place): + scope = core.Scope() + x_rows = [0, 5, 5, 4, 19] + out_rows = [0, 4, 5, 19] + height = 20 + row_numel = 2 + + np_array = np.ones((len(x_rows), row_numel)).astype("float32") + np_array[1, :] = 2.0 + np_array[2, :] = 3.0 + np_array[3, :] = 4.0 + + # initialize input variable X + x = scope.var('X').get_selected_rows() + x.set_rows(x_rows) + x.set_height(height) + x_tensor = x.get_tensor() + x_tensor.set(np_array, place) + + # initialize input variable Out + out = scope.var("Out").get_selected_rows() + + op = Operator("merge_selected_rows", X="X", Out="Out") + + op.run(scope, place) + + self.assertEqual(out.rows(), out_rows) + self.assertEqual(out.height(), height) + + out_array = np.array(out.get_tensor()) + self.assertEqual((4, 2), out_array.shape) + + assert (out_array[0, :] == 1.0).all() + assert (out_array[1, :] == 4.0).all() + assert (out_array[2, :] == 5.0).all() + assert (out_array[3, :] == 1.0).all() + + def test_check_output(self): + for place in self.get_places(): + self.check_with_place(place) + + +if __name__ == "__main__": + unittest.main()