diff --git a/paddle/fluid/operators/activation_cudnn.cu.cc b/paddle/fluid/operators/activation_cudnn.cu.cc
index 494c02374a9faa22486644c9b9c7d586c86d41b0..7f8ecc1df0734241bdd3c220786c3578a9e5193d 100644
--- a/paddle/fluid/operators/activation_cudnn.cu.cc
+++ b/paddle/fluid/operators/activation_cudnn.cu.cc
@@ -31,8 +31,8 @@ class CudnnActivationKernel
     ExtractActivationTensor(context, X, Out);
     ActivationDescriptor act_desc;
     TensorDescriptor x_desc, out_desc;
-    x_desc.set(detail::Ref(X));
-    out_desc.set(detail::Ref(Out));
+    x_desc.set(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivation"));
+    out_desc.set(GET_DATA_SAFELY(Out, "Output", "Out", "CudnnActivation"));
   }
 };
diff --git a/paddle/fluid/operators/activation_cudnn_op.cu.cc b/paddle/fluid/operators/activation_cudnn_op.cu.cc
index 84767f7c2240f4d324d924253656d406b3297f04..33d8fb828f86e1f689ad4f67ad0033b45ce2671e 100644
--- a/paddle/fluid/operators/activation_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/activation_cudnn_op.cu.cc
@@ -37,7 +37,7 @@ struct CudnnActivationFunctor {
     act_desc.set(mode_, coef_);
     TensorDescriptor x_desc, out_desc;
     x_desc.set(x);
-    out_desc.set(detail::Ref(out));
+    out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation"));
     PADDLE_ENFORCE(platform::dynload::cudnnActivationForward(
         ctx_.cudnn_handle(), act_desc.desc(),
         platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
@@ -63,7 +63,7 @@ struct CudnnActivationGradFunctor {
     x_desc.set(x);
     out_desc.set(out);
     dout_desc.set(dout);
-    dx_desc.set(detail::Ref(dx));
+    dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad"));
     PADDLE_ENFORCE(platform::dynload::cudnnActivationBackward(
         ctx_.cudnn_handle(), act_desc.desc(),
         platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
@@ -141,7 +141,7 @@ class CudnnActivationKernel
     Out->mutable_data<T>(context.GetPlace());
     auto& dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     Functor functor(dev_ctx);
-    functor(detail::Ref(X), Out);
+    functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivation"), Out);
   }
 };
@@ -161,7 +161,10 @@ class CudnnActivationGradKernel
     dX->mutable_data<T>(context.GetPlace());
     auto& dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     Functor functor(dev_ctx);
-    functor(detail::Ref(X), detail::Ref(Out), detail::Ref(dOut), dX);
+    functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivationGrad"),
+            GET_DATA_SAFELY(Out, "Input", "Out", "CudnnActivationGrad"),
+            GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "CudnnActivationGrad"),
+            dX);
   }
 };
diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index 8194b1ef44caed3e8a66e79c67cf9efbe7c4c3d9..30bb989bbf038888342de08dbe97fe1e40f38da8 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -26,7 +26,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/float16.h"
@@ -156,8 +155,10 @@ class ActivationKernel
     ExtractActivationTensor(context, &X, &Out);
     Out->mutable_data<T>(context.GetPlace());

-    auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
-    auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
+    auto x = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "Activation"));
+    auto out = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(Out, "Output", "Out", "Activation"));
     auto* place =
         context.template device_context<DeviceContext>().eigen_device();
     Functor functor;
@@ -182,10 +183,14 @@ class ActivationGradKernel
     ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
                                                     &dX);
     dX->mutable_data<T>(context.GetPlace());

-    auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
-    auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
-    auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
-    auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
+    auto dout = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad"));
+    auto out = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad"));
+    auto dx = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dX, "Output", "X@GRAD", "ActivationGrad"));
+    auto x = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad"));
     auto* place =
         context.template device_context<DeviceContext>().eigen_device();
     Functor functor;
@@ -1285,10 +1290,13 @@ struct ReluGradGradFunctor : public BaseActivationFunctor<T> {
                   framework::Tensor* ddOut, framework::Tensor* dOut,
                   framework::Tensor* dX) const {
     auto* d = dev.eigen_device();
-    auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
-    auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
+    auto ddx = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad"));
+    auto out = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad"));
     if (ddOut) {
-      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
+      auto ddout = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad"));
       ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>();
     }
   }
@@ -1308,9 +1316,12 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
                   framework::Tensor* dX) const {
     if (ddOut) {
       auto* d = dev.eigen_device();
-      auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
-      auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
-      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
+      auto ddx = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad"));
+      auto out = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(Out, "Output", "Out", "LeakyReluGradGrad"));
+      auto ddout = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LeakyReluGradGrad"));
       ddout.device(*d) =
           ddx * ((out > static_cast<T>(0)).template cast<T>() +
                  static_cast<T>(alpha) *
@@ -1332,18 +1343,23 @@ struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
                   const framework::Tensor* ddX, framework::Tensor* ddOut,
                   const framework::Tensor* dOut, framework::Tensor* dX) const {
     auto* d = dev.eigen_device();
-    auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
-    auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
+    auto ddx = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad"));
+    auto x = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad"));
     if (dX) {
-      auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
-      auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
+      auto dx = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad"));
+      auto dout = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad"));
       dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
                       (x < static_cast<T>(0)).template cast<T>();
     }
     if (ddOut) {
-      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
+      auto ddout = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad"));
       ddout.device(*d) = ddx * ((x > static_cast<T>(0)).template cast<T>() +
                                 static_cast<T>(alpha) * x.exp() *
@@ -1361,17 +1377,22 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
                   const framework::Tensor* ddX, framework::Tensor* ddOut,
                   framework::Tensor* dOut, const framework::Tensor* dX) const {
     auto* d = dev.eigen_device();
-    auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
-    auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
+    auto ddx = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
+    auto out = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
     // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
     // calculate dy first, so ddy can inplace ddx
     if (dOut) {
-      auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
-      auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
+      auto dx = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
+      auto dout = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
       dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
     }
     if (ddOut) {
-      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
+      auto ddout = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
       ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
     }
   }
@@ -1385,17 +1406,22 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
                   const framework::Tensor* ddX, framework::Tensor* ddOut,
                   const framework::Tensor* dOut, framework::Tensor* dX) const {
     auto* d = dev.eigen_device();
-    auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
-    auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
+    auto ddx = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad"));
+    auto x = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad"));
     // square GradGrad: ddy=2x*ddx, dx=2dy*ddx
     // calculate dx first, so ddy can inplace ddx
     if (dX) {
-      auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
-      auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
+      auto dx = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad"));
+      auto dout = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad"));
       dx.device(*d) = ddx * static_cast<T>(2) * dout;
     }
     if (ddOut) {
-      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
+      auto ddout = framework::EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
       ddout.device(*d) = ddx * static_cast<T>(2) * x;
     }
   }
@@ -1557,8 +1583,10 @@ class PowKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
     ExtractActivationTensor(context, &X, &Out);
     Out->mutable_data<T>(context.GetPlace());

-    auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
-    auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
+    auto x = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "Pow"));
+    auto out = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(Out, "Output", "Out", "Pow"));
     auto* place =
         context.template device_context<DeviceContext>().eigen_device();
     Functor functor;
@@ -1602,10 +1630,14 @@ class PowGradKernel
     ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
                                                     &dX);
     dX->mutable_data<T>(context.GetPlace());

-    auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
-    auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
-    auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
-    auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
+    auto dout = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "PowGrad"));
+    auto out = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(Out, "Input", "Out", "PowGrad"));
+    auto dx = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(dX, "Output", "X@GRAD", "PowGrad"));
+    auto x = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "PowGrad"));
     auto* place =
         context.template device_context<DeviceContext>().eigen_device();
     Functor functor;
diff --git a/paddle/fluid/operators/add_position_encoding_op.h b/paddle/fluid/operators/add_position_encoding_op.h
index 0b40d3de890a02a9dbec2328f9f6388ffa35561b..30d54e5cf98e31cb5f7523d4a2886d0d6aefc89b 100644
--- a/paddle/fluid/operators/add_position_encoding_op.h
+++ b/paddle/fluid/operators/add_position_encoding_op.h
@@ -15,7 +15,6 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"

 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/controlflow/get_places_op.cc b/paddle/fluid/operators/controlflow/get_places_op.cc
index 6bbe0cefbd6089feabd4797c554239c9cbe83fa8..eff88f54ade6e4bc71e8d80771b3f757819354a9 100644
--- a/paddle/fluid/operators/controlflow/get_places_op.cc
+++ b/paddle/fluid/operators/controlflow/get_places_op.cc
@@ -14,7 +14,6 @@ limitations under the License. */
 #include <thread>  // NOLINT
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/platform/place.h"
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/gpu_info.h"
 #endif
@@ -56,10 +55,9 @@ class GetPlacesOp : public framework::OperatorBase {
"GPU" : "CPU"); auto out_var_name = Output("Out"); - auto &places = - *(detail::Ref(scope.FindVar(out_var_name), - "Output variable %s cannot be found", out_var_name) - .GetMutable()); + auto &places = *(GET_DATA_SAFELY(scope.FindVar(out_var_name), "Output", + "Out", "GetPlaces") + .GetMutable()); places.reserve(device_count); if (is_gpu) { PADDLE_ENFORCE_LE(device_count, CUDADevCount(), diff --git a/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc index 504a72fafe6e83d9b2f90c58db1c1c7c9f06e6ee..fe9dacb532314e234f26d656f476e75f1603a021 100644 --- a/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc +++ b/paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc @@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/array_operator.h" -#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/math/math_function.h" namespace paddle { diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 33a3f88d24e2bc5979ab04b61b2176208e176e2c..f6aaa49eceda0aacc1f76b235672cdb75ceba3b8 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -19,7 +19,6 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h" -#include "paddle/fluid/operators/detail/safe_ref.h" namespace paddle { namespace operators { @@ -198,23 +197,18 @@ class WhileGradOp : public framework::OperatorBase { continue; } - auto &og_outside = - detail::Ref(scope.FindVar(outside_og_name), - "Cannot find Outside Gradient %s", outside_og_name); - auto &og_inside = - detail::Ref(cur_scope.Var(inside_og_name), - "Cannot find inside gradient %s", inside_og_name); + auto &og_outside = *scope.FindVar(outside_og_name); + auto &og_inside = *cur_scope.Var(inside_og_name); if (og_outside.IsType()) { auto &outside_tensor = og_outside.Get(); - auto &inside_tensor = - detail::Ref(og_inside.GetMutable()); + auto &inside_tensor = *og_inside.GetMutable(); inside_tensor.set_lod(outside_tensor.lod()); inside_tensor.ShareDataWith(outside_tensor); } else if (og_outside.IsType()) { auto outside_array = og_outside.GetMutable(); auto &inside_array = - detail::Ref(og_inside.GetMutable()); + *og_inside.GetMutable(); inside_array.clear(); inside_array.resize(outside_array->size()); VLOG(8) << outside_og_name << " size = " << outside_array->size(); diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index 2480916d21eab6c904c571887797ead0e0e3debc..ec9adafaa0dd35a8264741946be7256248596048 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -20,7 +20,6 @@ limitations under the License. 
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/depthwise_conv.h"
 #include "paddle/fluid/operators/math/im2col.h"
@@ -674,9 +673,8 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
     Tensor* ddY = ctx.Output<Tensor>("DDOutput");
     Tensor* dW = ctx.Output<Tensor>("DFilter");
     Tensor* dX = ctx.Output<Tensor>("DInput");
-    Tensor W = detail::Ref(ctx.Input<Tensor>("Filter"),
-                           "Cannot find input Filter(%s) in scope)",
-                           ctx.InputNames("Filter")[0]);
+    Tensor W = GET_DATA_SAFELY(ctx.Input<Tensor>("Filter"), "Input", "Filter",
+                               "GemmConvDoubleGrad");
     if (!ddY && !dW && !dX) return;

     const int groups = ctx.Attr<int>("groups");
diff --git a/paddle/fluid/operators/cum_op.h b/paddle/fluid/operators/cum_op.h
index d158bd4dfe55eaeb23718a6596c4f80d018ac7b1..3e975420e3ef1d7dd6e442ed093da902e9d88251 100644
--- a/paddle/fluid/operators/cum_op.h
+++ b/paddle/fluid/operators/cum_op.h
@@ -18,7 +18,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"

 namespace paddle {
 namespace operators {
@@ -29,13 +28,11 @@ class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
   using T = typename Functor::ELEMENT_TYPE;

   void Compute(const framework::ExecutionContext& context) const override {
-    auto& X = detail::Ref(context.Input<framework::Tensor>("X"),
-                          "Cannot get input tensor X, variable name = %s",
-                          context.InputName("X"));
+    auto& X = GET_DATA_SAFELY(context.Input<framework::Tensor>("X"), "Input",
+                              "X", "Cum");

-    auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"),
-                            "Cannot get output tensor Out, variable name = %s",
-                            context.OutputName("Out"));
+    auto& Out = GET_DATA_SAFELY(context.Output<framework::Tensor>("Out"),
+                                "Output", "Out", "Cum");
     int axis = context.Attr<int>("axis");
     bool exclusive = context.Attr<bool>("exclusive");
     bool reverse = context.Attr<bool>("reverse");
@@ -46,7 +43,7 @@ class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
     PADDLE_ENFORCE_LT(
         axis, x_dims.size(),
         "axis should be less than the dimension of the input tensor");
-    Out.mutable_data<T>(context.GetPlace());
+    Out.template mutable_data<T>(context.GetPlace());

     int pre = 1;
     int post = 1;
diff --git a/paddle/fluid/operators/detail/safe_ref.h b/paddle/fluid/operators/detail/safe_ref.h
deleted file mode 100644
index c56329d9ee5ab73c6a683c9ea0955e27bdc65564..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/detail/safe_ref.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <vector>
-#include "paddle/fluid/platform/enforce.h"
-
-namespace paddle {
-namespace operators {
-namespace detail {
-/**
- * Get Reference From Pointer with check. The error message is printf format,
- * and passed by `args`
- */
-template <typename T, typename... ARGS>
-inline T& Ref(T* ptr, ARGS&&... args) {
-  PADDLE_ENFORCE_NOT_NULL(ptr, ::paddle::string::Sprintf(args...));
-  return *ptr;
-}
-
-template <typename T, typename... ARGS>
-inline std::vector<std::reference_wrapper<T>> VectorRef(
-    const std::vector<T*>& vec, ARGS&&... args) {
-  std::vector<std::reference_wrapper<T>> result;
-  result.reserve(vec.size());
-  for (auto* ptr : vec) {
-    result.emplace_back(Ref(ptr, args...));
-  }
-  return result;
-}
-
-}  // namespace detail
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.h b/paddle/fluid/operators/detection/collect_fpn_proposals_op.h
index c125d55ecca42edf7b5a604fdd2a1bb5bcad87c4..98bba5343cfa7b5c1d4b58ce2ec3b76d2a5168a4 100644
--- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.h
+++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.h
@@ -20,7 +20,6 @@ limitations under the License.*/
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/gather.h"
 #include "paddle/fluid/operators/math/math_function.h"
diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h
index 536ffe25a5600ab407a74af072519b36ae8ffe76..09fb4a382d02717476cb866ba9868e51db28ae7c 100644
--- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h
+++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.h
@@ -20,7 +20,6 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/gather.h"
 #include "paddle/fluid/operators/math/math_function.h"
diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cc b/paddle/fluid/operators/detection/generate_proposals_op.cc
index 6d2312d3b6aa378c891c8f26d78d2362666dcf2d..2ab094cd8b88702bc27826cfaeb2cd9ba73f8ba3 100644
--- a/paddle/fluid/operators/detection/generate_proposals_op.cc
+++ b/paddle/fluid/operators/detection/generate_proposals_op.cc
@@ -17,7 +17,6 @@ limitations under the License. */
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/gather.h"
 #include "paddle/fluid/operators/math/math_function.h"
@@ -293,12 +292,10 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
     auto *scores = context.Input<Tensor>("Scores");
     auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
     auto *im_info = context.Input<Tensor>("ImInfo");
-    auto anchors = detail::Ref(context.Input<Tensor>("Anchors"),
-                               "Cannot find input Anchors(%s) in scope",
-                               context.InputNames("Anchors")[0]);
-    auto variances = detail::Ref(context.Input<Tensor>("Variances"),
-                                 "Cannot find input Variances(%s) in scope",
-                                 context.InputNames("Variances")[0]);
+    auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"), "Input",
+                                   "Anchors", "GenerateProposals");
+    auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
+                                     "Input", "Variances", "GenerateProposals");
     auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
     auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cu b/paddle/fluid/operators/detection/generate_proposals_op.cu
index 9a25e205bc81b8f58ec77c283895d1c51b30927d..10e111d6673cf43041c4267bebc089f834ddb99c 100644
--- a/paddle/fluid/operators/detection/generate_proposals_op.cu
+++ b/paddle/fluid/operators/detection/generate_proposals_op.cu
@@ -20,7 +20,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/memory.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/gather.cu.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/for_range.h"
@@ -367,12 +366,10 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
     auto *scores = context.Input<Tensor>("Scores");
     auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
     auto *im_info = context.Input<Tensor>("ImInfo");
-    auto anchors = detail::Ref(context.Input<Tensor>("Anchors"),
-                               "Cannot find input Anchors(%s) in scope",
-                               context.InputNames("Anchors")[0]);
-    auto variances = detail::Ref(context.Input<Tensor>("Variances"),
-                                 "Cannot find input Variances(%s) in scope",
-                                 context.InputNames("Variances")[0]);
+    auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"), "Input",
+                                   "Anchors", "GenerateProposals");
+    auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
+                                     "Input", "Variances", "GenerateProposals");
     auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
     auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
diff --git a/paddle/fluid/operators/fill_op.h b/paddle/fluid/operators/fill_op.h
index 99700736c1b53aeb1595622df1931539d482c215..02e388bcb40f8d4ade56fa3f37be15ed0320e7c5 100644
--- a/paddle/fluid/operators/fill_op.h
+++ b/paddle/fluid/operators/fill_op.h
@@ -19,7 +19,6 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"

 namespace paddle {
 namespace operators {
@@ -44,10 +43,8 @@ template <typename T>
 class FillKernel : public framework::OpKernel<T> {
  public:
  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
-    auto &out =
-        detail::Ref(ctx.Output<framework::LoDTensor>("Out"),
-                    "Cannot get output lod tensor Out, variable name = %s",
-                    ctx.OutputName("Out"));
+    auto &out = GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Out"),
+                                "Output", "Out", "Fill");
     out.Resize(framework::make_ddim(ctx.Attr<std::vector<int>>("shape")));
     auto dtype =
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
diff --git a/paddle/fluid/operators/fused/fused_elemwise_activation_op.h b/paddle/fluid/operators/fused/fused_elemwise_activation_op.h
index 62a6175e33fe063bd6e5efdd5e123b745770c1fb..2c0c5f9ec0afa49396540d454aceb5d2df1542cf 100644
--- a/paddle/fluid/operators/fused/fused_elemwise_activation_op.h
+++ b/paddle/fluid/operators/fused/fused_elemwise_activation_op.h
@@ -18,7 +18,6 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/compound_functors.h"
 #include "paddle/fluid/operators/math/functors.h"
@@ -383,12 +382,10 @@ template <typename DeviceContext, typename T>
 class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto &in_x = detail::Ref(ctx.Input<framework::LoDTensor>("X"),
-                             "Cannot get input tensor %s, variable name = %s",
-                             "X", ctx.InputName("X"));
-    auto &in_y = detail::Ref(ctx.Input<framework::LoDTensor>("Y"),
-                             "Cannot get input tensor %s, variable name = %s",
-                             "Y", ctx.InputName("Y"));
+    auto &in_x = GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("X"), "Input",
+                                 "X", "FusedElemwiseActivation");
+    auto &in_y = GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("Y"), "Input",
+                                 "Y", "FusedElemwiseActivation");
     PADDLE_ENFORCE(ctx.HasOutput("Out"),
                    "The output(Out) should not be empty");
     auto output = ctx.Output<framework::LoDTensor>("Out");
diff --git a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc
index 81037cb3149fd334e6b681e7fac76e9571582a74..b53b407d4995da5d548a13fec20ff3b09a5583c4 100644
--- a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc
+++ b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cc
@@ -14,7 +14,6 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/platform/errors.h"

 namespace paddle {
diff --git a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu
index 0461cedbd7efe9083febc06494c5ae65be3baae0..d8bd5d03a7d17786cafc2ca865fc72d5f2eca9a3 100644
--- a/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu
+++ b/paddle/fluid/operators/fused/fused_embedding_eltwise_layernorm_op.cu
@@ -19,7 +19,6 @@
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/malloc.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/bert_encoder_functor.h"
 #include "paddle/fluid/operators/math/blas.h"
diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cc b/paddle/fluid/operators/fused/multihead_matmul_op.cc
index ccc90ae368f8a0a8fb976d7595b0bffff0da3292..ad8db4c62ec6e446b8322cb3711cc8340b90b8c1 100644
--- a/paddle/fluid/operators/fused/multihead_matmul_op.cc
+++ b/paddle/fluid/operators/fused/multihead_matmul_op.cc
@@ -14,7 +14,6 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/platform/errors.h"

 namespace paddle {
diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu
index d9ab092bbe5fa93be00ef4b0900b7402d0fa8b4e..fb5ce3468538a3edecc89ecc558e5146506795bd 100644
--- a/paddle/fluid/operators/fused/multihead_matmul_op.cu
+++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu
@@ -17,7 +17,6 @@
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/malloc.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/bert_encoder_functor.h"
 #include "paddle/fluid/operators/math/blas.h"
@@ -142,14 +141,13 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
     auto *input = context.Input<framework::Tensor>("Input");
     auto *w = context.Input<framework::Tensor>("W");
     auto *bias = context.Input<framework::Tensor>("Bias");
-
-    auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"),
-                                "Cannot find QK");
+    auto &bias_qk = GET_DATA_SAFELY(context.Input<framework::Tensor>("BiasQK"),
+                                    "Input", "BiasQK", "MultiHeadMatMulV2");

     auto *input_d = input->data<T>();
     auto *w_d = w->data<T>();
     auto *bias_d = bias->data<T>();
-    auto *bias_qk_d = bias_qk.data<T>();
+    auto *bias_qk_d = bias_qk.template data<T>();
     T scale = static_cast<T>(context.Attr<float>("alpha"));
     int head_number = context.Attr<int>("head_number");
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h
index 752773670433006e40b899f88bf7f078fa958464..e873d909da1c1da5f696759c567770b674b871df 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.h
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h
@@ -24,7 +24,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/clip_op.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/matrix_bit_code.h"
 #include "paddle/fluid/platform/transform.h"
@@ -40,8 +39,9 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 using platform::Transform;
+using framework::LoDTensor;

-static std::vector<int64_t> PathToRows(const framework::LoDTensor& path) {
+static std::vector<int64_t> PathToRows(const LoDTensor& path) {
   std::set<int64_t> rows;
   const int64_t* paths = path.data<int64_t>();
   for (int64_t i = 0; i < path.numel(); ++i) {
@@ -57,14 +57,17 @@ template <typename DeviceContext, typename T>
 class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
-    auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
-    auto* path = ctx.Input<framework::LoDTensor>("PathTable");
-    auto* code = ctx.Input<framework::LoDTensor>("PathCode");
-    auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
-    auto* bias = ctx.Input<framework::LoDTensor>("Bias");
-    auto* out = ctx.Output<framework::LoDTensor>("Out");
-    auto* pre_out = ctx.Output<framework::LoDTensor>("PreOut");
+    auto& in = GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X",
+                               "HierarchicalSigmoid");
+    auto& w = GET_DATA_SAFELY(ctx.Input<LoDTensor>("W"), "Input", "W",
+                              "HierarchicalSigmoid");
+    auto* path = ctx.Input<LoDTensor>("PathTable");
+    auto* code = ctx.Input<LoDTensor>("PathCode");
+    auto& label = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Label"), "Input",
+                                  "Label", "HierarchicalSigmoid");
+    auto* bias = ctx.Input<LoDTensor>("Bias");
+    auto* out = ctx.Output<LoDTensor>("Out");
+    auto* pre_out = ctx.Output<LoDTensor>("PreOut");
     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
     // for remote prefetch
@@ -75,7 +78,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     int64_t code_length =
         path ? path->dims()[1] : math::FindLastSet(num_classes - 1);
     int64_t batch_size = in.dims()[0];
-    framework::LoDTensor sum;
+    LoDTensor sum;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto* pre_out_data = pre_out->mutable_data<T>(
         framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
@@ -89,11 +92,11 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
     if (!is_custom) {
-      bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes,
-                                                       label.data<int64_t>()));
+      bit_code.reset(new math::MatrixBitCodeFunctor<T>(
+          num_classes, label.template data<int64_t>()));
     } else {
-      bit_code.reset(new math::MatrixBitCodeFunctor<T>(*path, *code,
-                                                       label.data<int64_t>()));
+      bit_code.reset(new math::MatrixBitCodeFunctor<T>(
+          *path, *code, label.template data<int64_t>()));
     }

     std::vector<int64_t> sum_dims({batch_size, 1UL});
@@ -126,20 +129,24 @@ template <typename DeviceContext, typename T>
 class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
-    auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
-    auto* path = ctx.Input<framework::LoDTensor>("PathTable");
-    auto* code = ctx.Input<framework::LoDTensor>("PathCode");
-    auto* in_grad =
-        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto& in = GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X",
+                               "HierarchicalSigmoidGrad");
+    auto& w = GET_DATA_SAFELY(ctx.Input<LoDTensor>("W"), "Input", "W",
+                              "HierarchicalSigmoidGrad");
+    auto* path = ctx.Input<LoDTensor>("PathTable");
+    auto* code = ctx.Input<LoDTensor>("PathCode");
+    auto* in_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
     bool is_sparse = ctx.Attr<bool>("is_sparse");
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     math::SetConstant<DeviceContext, T> zero;
-    auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
-    auto& pre_out = detail::Ref(ctx.Input<framework::LoDTensor>("PreOut"));
-    auto& out_grad = detail::Ref(
-        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out")));
-    framework::LoDTensor pre_out_grad;
+    auto& label = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Label"), "Input",
+                                  "Label", "HierarchicalSigmoidGrad");
+    auto& pre_out = GET_DATA_SAFELY(ctx.Input<LoDTensor>("PreOut"), "Input",
+                                    "PreOut", "HierarchicalSigmoidGrad");
+    auto& out_grad = GET_DATA_SAFELY(
+        ctx.Input<LoDTensor>(framework::GradVarName("Out")), "Input",
+        framework::GradVarName("Out"), "HierarchicalSigmoidGrad");
+    LoDTensor pre_out_grad;

     pre_out_grad.mutable_data<T>(pre_out.dims(), ctx.GetPlace());
     in_grad->mutable_data<T>(ctx.GetPlace());
@@ -154,11 +161,11 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
     if (!is_custom) {
-      bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes,
-                                                       label.data<int64_t>()));
+      bit_code.reset(new math::MatrixBitCodeFunctor<T>(
+          num_classes, label.template data<int64_t>()));
     } else {
-      bit_code.reset(new math::MatrixBitCodeFunctor<T>(*path, *code,
-                                                       label.data<int64_t>()));
+      bit_code.reset(new math::MatrixBitCodeFunctor<T>(
+          *path, *code, label.template data<int64_t>()));
     }

     // softrelu derivative
@@ -166,7 +173,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     auto blas = math::GetBlas<DeviceContext, T>(ctx);

     auto* pre_out_grad_data = pre_out_grad.data<T>();
-    auto* pre_out_data = pre_out.data<T>();
+    auto* pre_out_data = pre_out.template data<T>();
     auto n = pre_out.numel();
     blas.VEXP(n, pre_out_data, pre_out_grad_data);
     blas.VINV(n, pre_out_grad_data, pre_out_grad_data);
@@ -174,7 +181,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
       pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i];
     }
     bit_code->Sub(&pre_out_grad);  // the gradient of clip(w * x + b)
-    auto* out_grad_data = out_grad.data<T>();
+    auto* out_grad_data = out_grad.template data<T>();

     int64_t dim0 = pre_out_grad.dims()[0];
     int64_t dim1 = pre_out_grad.dims()[1];
@@ -184,16 +191,14 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     }
     // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
     // be consistent with the clipping in forward.
-    auto* bias_grad =
-        ctx.Output<framework::LoDTensor>(framework::GradVarName("Bias"));
+    auto* bias_grad = ctx.Output<LoDTensor>(framework::GradVarName("Bias"));
     if (bias_grad) {
       bias_grad->mutable_data<T>(ctx.GetPlace());
       zero(dev_ctx, bias_grad, static_cast<T>(0.0));
       bit_code->AddGrad(pre_out_grad, bias_grad);
     }
     if (!is_sparse) {
-      auto* w_grad =
-          ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
+      auto* w_grad = ctx.Output<LoDTensor>(framework::GradVarName("W"));
       w_grad->mutable_data<T>(ctx.GetPlace());
       zero(dev_ctx, w_grad, static_cast<T>(0.0));
       bit_code->MulGradWeight(pre_out_grad, w_grad, in);
diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc
index ca8a4a09e707a29a0a505dba961594f6b3aec526..50a2a2c9467fb15d44631e3257ed74bf2c0334dc 100644
--- a/paddle/fluid/operators/lod_tensor_to_array_op.cc
+++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc
@@ -16,7 +16,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/port.h"
@@ -95,13 +94,15 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
-    auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input %s",
-                          Input("X"))
+    auto &x = GET_DATA_SAFELY(scope.FindVar(Input("X")), "Input", "X",
+                              "LoDTensorToArray")
                   .Get<framework::LoDTensor>();
-    auto &rank_table = detail::Ref(scope.FindVar(Input("RankTable")))
+    auto &rank_table = GET_DATA_SAFELY(scope.FindVar(Input("RankTable")),
+                                       "Input", "RankTable", "LoDTensorToArray")
                            .Get<framework::LoDRankTable>();
-    auto &out = *detail::Ref(scope.FindVar(Output("Out")))
-                     .GetMutable<framework::LoDTensorArray>();
+    auto &out = *(GET_DATA_SAFELY(scope.FindVar(Output("Out")), "Output", "Out",
+                                  "LoDTensorToArray")
+                      .GetMutable<framework::LoDTensorArray>());
     auto &items = rank_table.items();
     auto max_seq_len = items[0].length;
     auto rank_level = rank_table.level();
diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc
index af9e0644c9269a234a6a01153ef1b194ca63cbb0..69609fa5bcdeb24e73e0ea6608a16445388477e4 100644
--- a/paddle/fluid/operators/matmul_op.cc
+++ b/paddle/fluid/operators/matmul_op.cc
@@ -16,7 +16,6 @@ limitations under the License. */
 #include <algorithm>
 #include <utility>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/blas.h"

 namespace paddle {
@@ -58,10 +57,10 @@ template <typename DeviceContext, typename T>
 class MatMulKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &context) const override {
-    auto &x =
-        detail::Ref(context.Input<framework::Tensor>("X"), "Cannot find X");
-    auto &y =
-        detail::Ref(context.Input<framework::Tensor>("Y"), "Cannot find Y");
+    auto &x = GET_DATA_SAFELY(context.Input<framework::Tensor>("X"), "Input",
+                              "X", "MatMul");
+    auto &y = GET_DATA_SAFELY(context.Input<framework::Tensor>("Y"), "Input",
+                              "Y", "MatMul");
     auto *out = context.Output<framework::Tensor>("Out");
     out->mutable_data<T>(context.GetPlace());
diff --git a/paddle/fluid/operators/mkldnn/mkldnn_activation_op.h b/paddle/fluid/operators/mkldnn/mkldnn_activation_op.h
index aa34a092bd6bbaa3b0700b1cab11a69b2f2beff4..6c294a9518653ed6de6b8699cfc44c4539661fde 100644
--- a/paddle/fluid/operators/mkldnn/mkldnn_activation_op.h
+++ b/paddle/fluid/operators/mkldnn/mkldnn_activation_op.h
@@ -17,7 +17,6 @@ limitations under the License. */

 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"

 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
diff --git a/paddle/fluid/operators/optimizers/adam_op.cu b/paddle/fluid/operators/optimizers/adam_op.cu
index b130ffe6464ddae9e81f5eb416aad3677751c4a4..fbab8cf063b2c39b779d939534e5202379df37bb 100644
--- a/paddle/fluid/operators/optimizers/adam_op.cu
+++ b/paddle/fluid/operators/optimizers/adam_op.cu
@@ -128,7 +128,6 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
                           framework::ToTypeName(param_var->Type())));
     using paddle::framework::LoDTensor;
-    using paddle::operators::detail::Ref;

     int64_t min_row_size_to_use_multithread =
         ctx.Attr<int64_t>("min_row_size_to_use_multithread");
diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index 461c94976ffc84d7b5db2b9893f834fc66893829..11452480227bbd628e384e722aadac381f7be045 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -20,7 +20,6 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/threadpool.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/algorithm.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/for_range.h"
@@ -384,7 +383,6 @@ class AdamOpKernel : public framework::OpKernel<T> {
                           framework::ToTypeName(param_var->Type()));
     using paddle::framework::LoDTensor;
-    using paddle::operators::detail::Ref;

     int64_t min_row_size_to_use_multithread =
         ctx.Attr<int64_t>("min_row_size_to_use_multithread");
diff --git a/paddle/fluid/operators/optimizers/lamb_op.h b/paddle/fluid/operators/optimizers/lamb_op.h
index e6d518a4f731f806d7a4271d58d24ae1dcca11c3..b1e37e2b217504129c5086d7686449ce13893ea3 100644
--- a/paddle/fluid/operators/optimizers/lamb_op.h
+++ b/paddle/fluid/operators/optimizers/lamb_op.h
@@ -17,7 +17,6 @@ limitations under the License. */
 #include <math.h>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/algorithm.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/for_range.h"
@@ -185,30 +184,32 @@ class LambOpKernel : public framework::OpKernel<T> {
                           framework::ToTypeName(param_var->Type()));
     using paddle::framework::LoDTensor;
-    using paddle::operators::detail::Ref;

     T weight_decay = static_cast<T>(ctx.Attr<float>("weight_decay"));
     T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
     T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
     T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
-    auto& param = Ref(ctx.Input<LoDTensor>("Param"), "Must set Param.");
+    auto& param = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Param"), "Input",
+                                  "Param", "Lamb");
     auto* grad_var = ctx.InputVar("Grad");
-    auto& mom1 = Ref(ctx.Input<LoDTensor>("Moment1"), "Must set Moment1.");
-    auto& mom2 = Ref(ctx.Input<LoDTensor>("Moment2"), "Must set Moment2.");
-    auto& lr =
-        Ref(ctx.Input<LoDTensor>("LearningRate"), "Must set LearningRate.");
-
-    auto& beta1_pow =
-        Ref(ctx.Input<LoDTensor>("Beta1Pow"), "Must set Beta1Pow.");
-    auto& beta2_pow =
-        Ref(ctx.Input<LoDTensor>("Beta2Pow"), "Must set Beta2Pow.");
-
-    auto& param_out =
-        Ref(ctx.Output<LoDTensor>("ParamOut"), "Must set ParamOut.");
-    auto& mom1_out =
-        Ref(ctx.Output<LoDTensor>("Moment1Out"), "Must set Moment1Out.");
-    auto& mom2_out =
-        Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment1Out.");
+    auto& mom1 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment1"), "Input",
+                                 "Moment1", "Lamb");
+    auto& mom2 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment2"), "Input",
+                                 "Moment2", "Lamb");
+    auto& lr = GET_DATA_SAFELY(ctx.Input<LoDTensor>("LearningRate"), "Input",
+                               "LearningRate", "Lamb");
+
+    auto& beta1_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta1Pow"),
+                                      "Input", "Beta1Pow", "Lamb");
+    auto& beta2_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta2Pow"),
+                                      "Input", "Beta2Pow", "Lamb");
+
+    auto& param_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("ParamOut"),
+                                      "Output", "ParamOut", "Lamb");
+    auto& mom1_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment1Out"),
+                                     "Output", "Moment1Out", "Lamb");
+    auto& mom2_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment2Out"),
+                                     "Output", "Moment2Out", "Lamb");

     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     platform::ForRange<DeviceContext> for_range(dev_ctx, param.numel());
@@ -217,7 +218,7 @@ class LambOpKernel : public framework::OpKernel<T> {

     // Update moments
     if (grad_var->IsType<LoDTensor>()) {
-      auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad.");
+      auto& grad = *ctx.Input<LoDTensor>("Grad");
       LambMomentUpdateFunctor<T> moment_update_functor(
           weight_decay, beta1, beta2, epsilon, beta1_pow.template data<T>(),
@@ -229,8 +230,8 @@ class LambOpKernel : public framework::OpKernel<T> {
           trust_ratio_div.template data<T>());
       for_range(moment_update_functor);
     } else if (grad_var->IsType<framework::SelectedRows>()) {
-      auto& grad =
-          Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad.");
+      auto& grad = GET_DATA_SAFELY(ctx.Input<framework::SelectedRows>("Grad"),
+                                   "Input", "Grad", "Lamb");
       if (grad.rows().size() == 0) {
         VLOG(3) << "grad row size is 0!!";
         return;
diff --git a/paddle/fluid/operators/random_crop_op.h b/paddle/fluid/operators/random_crop_op.h
index ae58358cbb202c229cb8a96e20e4c48d926c5bf3..c84f886c80fe0aef7fdf0c5bb9c7d562ae032b41 100644
--- a/paddle/fluid/operators/random_crop_op.h
+++ b/paddle/fluid/operators/random_crop_op.h
@@ -16,7 +16,6 @@

 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/for_range.h"
 #ifdef PADDLE_WITH_CUDA
@@ -152,10 +151,11 @@ class RandomCropKernel : public framework::OpKernel<T> {
  public:
  virtual void Compute(const framework::ExecutionContext& ctx) const {
     int64_t seed = 0;
-    auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed"));
+    auto& seed_tensor = GET_DATA_SAFELY(
+        ctx.Input<framework::LoDTensor>("Seed"), "Input", "Seed", "RandomCrop");
     if (seed_tensor.IsInitialized()) {
       if (platform::is_cpu_place(seed_tensor.place())) {
-        seed = *seed_tensor.data<int64_t>();
+        seed = *seed_tensor.template data<int64_t>();
       } else {
         LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
                         "your program";
@@ -169,13 +169,15 @@ class RandomCropKernel : public framework::OpKernel<T> {
       seed = ctx.Attr<int>("startup_seed");
     }
     auto shape = ctx.Attr<std::vector<int>>("shape");
-    auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
-    auto& out = detail::Ref(ctx.Output<framework::LoDTensor>("Out"));
+    auto& x = GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("X"), "Input",
+                              "X", "RandomCrop");
+    auto& out = GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Out"),
+                                "Output", "Out", "RandomCrop");

     int num_batchsize_dims = x.dims().size() - shape.size();
-    RandomCropFunctor<DeviceContext, T> functor(
-        x.data<T>(), out.mutable_data<T>(ctx.GetPlace()), x.dims(), out.dims(),
-        num_batchsize_dims, seed);
+    RandomCropFunctor<DeviceContext, T> functor(
+        x.template data<T>(), out.template mutable_data<T>(ctx.GetPlace()),
+        x.dims(), out.dims(), num_batchsize_dims, seed);
     platform::ForRange<DeviceContext> for_range(
         ctx.template device_context<DeviceContext>(),
         functor.prod_batchsize_dims_);
diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc
index a7d815367f33c4cdb28c94c03fe013916a3cf45b..bf44a2d53ee7e04b12363d6e15c2fe62ebb1a47b 100644
--- a/paddle/fluid/operators/reader/create_custom_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc
@@ -13,7 +13,6 @@
 // limitations under the License.

 #include "paddle/fluid/framework/executor.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/reader/reader_op_registry.h"

 namespace paddle {
@@ -171,8 +170,11 @@ void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
   // 3. Copy LoDTensors from sink variables to out.
   out->resize(sink_var_names_.size());
   for (size_t i = 0; i < sink_var_names_.size(); ++i) {
-    const auto& tensor = detail::Ref(exe_scope->FindVar(sink_var_names_[i]))
-                             .Get<framework::LoDTensor>();
+    auto* var = exe_scope->FindVar(sink_var_names_[i]);
+    PADDLE_ENFORCE_NOT_NULL(var, platform::errors::NotFound(
+                                     "The variable %s is not in current scope.",
+                                     sink_var_names_[i]));
+    const auto& tensor = var->Get<framework::LoDTensor>();
     framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
   }
   scope_.DeleteScope(exe_scope);
diff --git a/paddle/fluid/operators/reader/read_op.cc b/paddle/fluid/operators/reader/read_op.cc
index 9a4885056cdba5f837ccc9fd2e1fd443394f5f59..f23c858bb637d6f78a9ae9ca135c4ab50c3bfb86 100644
--- a/paddle/fluid/operators/reader/read_op.cc
+++ b/paddle/fluid/operators/reader/read_op.cc
@@ -15,7 +15,6 @@
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/reader.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/platform/profiler.h"

 namespace paddle {
@@ -96,8 +95,8 @@ class ReadOp : public framework::OperatorBase {
                const platform::Place& dev_place) const override {
     VLOG(3) << "read op in";
     framework::ReaderHolder* reader =
-        detail::Ref(scope.FindVar(Input("Reader")),
-                    "Cannot find reader variable %s", Input("Reader"))
+        GET_DATA_SAFELY(scope.FindVar(Input("Reader")), "Input", "Reader",
+                        "Read")
             .GetMutable<framework::ReaderHolder>();
     std::vector<std::string> out_arg_names = Outputs("Out");
     std::vector<framework::LoDTensor> ins;
diff --git a/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc b/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc
index ce9842c9955291115de91d3bf2dc212dbc6cc0e4..1a8d2e584cad32747b3826614b356804cf2ff22b 100644
--- a/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc
+++ b/paddle/fluid/operators/reorder_lod_tensor_by_rank_op.cc
@@ -14,7 +14,6 @@ limitations under the License. */

 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/platform/device_context.h"

 namespace paddle {
@@ -78,18 +77,16 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
-    auto &x =
-        detail::Ref(scope.FindVar(Input("X")),
-                    "Cannot find input lod tensor variable %s", Input("X"))
-            .Get<framework::LoDTensor>();
-    auto &rank_table = detail::Ref(scope.FindVar(Input("RankTable")),
-                                   "Cannot find input rank table variable %s",
-                                   Input("RankTable"))
-                           .Get<framework::LoDRankTable>();
-    auto &out =
-        *detail::Ref(scope.FindVar(Output("Out")),
-                     "Cannot find output lod tensor variable %s", Output("Out"))
-             .GetMutable<framework::LoDTensor>();
+    auto &x = GET_DATA_SAFELY(scope.FindVar(Input("X")), "Input", "X",
+                              "ReorderLoDTensorByRankTable")
+                  .Get<framework::LoDTensor>();
+    auto &rank_table =
+        GET_DATA_SAFELY(scope.FindVar(Input("RankTable")), "Input", "RankTable",
+                        "ReorderLoDTensorByRankTable")
+            .Get<framework::LoDRankTable>();
+    auto &out = *(GET_DATA_SAFELY(scope.FindVar(Output("Out")), "Output", "Out",
+                                  "ReorderLoDTensorByRankTable")
+                      .GetMutable<framework::LoDTensor>());

     out.Resize(x.dims());
     out.mutable_data(x.place(), x.type());
diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc
index b25674c26a179ee784e4c5b643fb08046745b5fb..e67348e39989295fd6a46b94d7a7203736b4212e 100644
--- a/paddle/fluid/operators/scale_op.cc
+++ b/paddle/fluid/operators/scale_op.cc
@@ -17,8 +17,6 @@ limitations under the License. */
 #include <memory>
 #include <string>
-#include "paddle/fluid/operators/detail/safe_ref.h"
-
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/sequence_ops/sequence_concat_op.h b/paddle/fluid/operators/sequence_ops/sequence_concat_op.h
index dc0acba122198815e25929bc0efb7257808e3e52..9c5cc5c80316980caa29668a834fd09e1c70bcbd 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_concat_op.h
+++ b/paddle/fluid/operators/sequence_ops/sequence_concat_op.h
@@ -18,7 +18,6 @@
 #include <vector>
 #include "boost/optional.hpp"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"
 #include "paddle/fluid/operators/math/concat_and_split.h"

 namespace paddle {
@@ -47,16 +46,28 @@ inline framework::LoD ConcatLoD(const Container &xs,
   lod.emplace_back(result);
   return lod;
 }
+
+template <typename T, typename... ARGS>
+inline std::vector<std::reference_wrapper<T>> GetDataVectorSafely(
+    const std::vector<T *> &vec, ARGS &&... args) {
+  std::vector<std::reference_wrapper<T>> result;
+  result.reserve(vec.size());
+  for (auto *ptr : vec) {
+    PADDLE_ENFORCE_NOT_NULL(ptr, platform::errors::InvalidArgument(
+                                     "The input variable X contains nullptr."));
+    result.emplace_back(*ptr);
+  }
+  return result;
+}
 }  // namespace detail

 template <typename DeviceContext, typename T>
 class SeqConcatKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext &context) const override {
-    auto xs = detail::VectorRef(context.MultiInput<framework::LoDTensor>("X"),
-                                "Cannot find multiple input X");
-    auto &out = detail::Ref(context.Output<framework::LoDTensor>("Out"),
-                            "Cannot find output");
+    auto xs = detail::GetDataVectorSafely(
+        context.MultiInput<framework::LoDTensor>("X"));
+    auto &out = *context.Output<framework::LoDTensor>("Out");

     size_t lod_size = 0;
     for (auto &x : xs) {
@@ -141,9 +152,9 @@ class SeqConcatGradKernel : public framework::OpKernel<T> {
       math::SplitFunctor<DeviceContext, T> functor;
       functor(context.template device_context<DeviceContext>(),
-              detail::Ref(
+              GET_DATA_SAFELY(
                   context.Input<framework::Tensor>(framework::GradVarName("Out")),
-                  "Sequence Concat OG must be set"),
+                  "Input", framework::GradVarName("Out"), "SeqConcatGrad"),
               sliced_x_ptr, 0, &sliced_dx_ptr);
     }
 };
diff --git a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h
index b67488200c317680e1eca0786c8ef1313c3d0f9c..b215c894273b0a4c52d4984b05584a7cc0dadae0 100644
--- a/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h
+++ b/paddle/fluid/operators/sequence_ops/sequence_expand_as_op.h
@@ -19,7 +19,6 @@ limitations under the License. */
 #include <memory>
 #include "glog/logging.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"

 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc
index a854a11f64c4ab6a5ac68dabe5772c3b042e910a..8340a6191ccff1bf3cce384745c6045cdf9a8d8d 100644
--- a/paddle/fluid/operators/sum_op.cc
+++ b/paddle/fluid/operators/sum_op.cc
@@ -18,7 +18,6 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/var_type_inference.h"
-#include "paddle/fluid/operators/detail/safe_ref.h"

 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h
index 45aa32e17c32285a1113e9068c9c996b70b7cc22..ccb859b24d7aea92f519c9fd7761181a978853e4 100644
--- a/paddle/fluid/platform/enforce.h
+++ b/paddle/fluid/platform/enforce.h
@@ -350,6 +350,51 @@ struct EnforceNotMet : public std::exception {

 /** EXTENDED TOOL FUNCTIONS WITH CHECKING **/

+/*
+ * Summary: This macro is used to get Variable or internal type
+ * data (such as LoDTensor or SelectedRows) of the Input and
+ * Output in op, generally used when calling scope.FindVar(Input/
+ * Output("Name")) or ctx.Input<LoDTensor>().
+ * This macro first checks whether the obtained pointer is null,
+ * and then returns the data it points to if it is not.
+ *
+ * Note: This macro is only suitable for specific scenarios and
+ * is not intended to be widely used. If it cannot meet the
+ * requirements, please use the other PADDLE_ENFORCE** check macros.
+ *
+ * Parameters:
+ *     __PTR: pointer
+ *     __ROLE: (string), Input or Output
+ *     __NAME: (string), Input or Output name
+ *     __OP_TYPE: (string), the op type
+ *
+ * Return: The data pointed to by the pointer.
+ *
+ * Examples:
+ *    GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X", "Mul");
+*/
+#define GET_DATA_SAFELY(__PTR, __ROLE, __NAME, __OP_TYPE)                   \
+  (([&]() -> std::add_lvalue_reference<decltype(*(__PTR))>::type {          \
+    auto* ptr = (__PTR);                                                    \
+    if (UNLIKELY(nullptr == ptr)) {                                         \
+      __THROW_ERROR_INTERNAL__(                                             \
+          "%s\n  [Hint: pointer " #__PTR " should not be null.]",           \
+          paddle::platform::errors::NotFound(                               \
+              "Unable to get %s data of %s %s in operator %s. "             \
+              "Possible reasons are:\n"                                     \
+              "  1. The %s is not the %s of operator %s;\n"                 \
+              "  2. The %s has no corresponding variable passed in;\n"      \
+              "  3. The %s corresponding variable is not initialized.",     \
+              paddle::platform::demangle(                                   \
+                  typeid(std::add_lvalue_reference<decltype(*ptr)>::type)   \
+                      .name()),                                             \
+              __ROLE, __NAME, __OP_TYPE, __NAME, __ROLE, __OP_TYPE, __NAME, \
+              __NAME)                                                       \
+              .ToString());                                                 \
+    }                                                                       \
+    return *ptr;                                                            \
+  })())
+
 /*
  * Summary: This macro is used to check whether op has specified
  * Input or Output Variables. Because op's Input and Output
diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc
index 1215005ad80de9119533e714aa447cf874690dde..0057c784528c2654fbe58aa7e48c91e27f9843de 100644
--- a/paddle/fluid/platform/enforce_test.cc
+++ b/paddle/fluid/platform/enforce_test.cc
@@ -362,6 +362,22 @@ TEST(enforce, cannot_to_string_type) {
   PADDLE_ENFORCE_NE(list.begin(), list.end());
 }

+TEST(GET_DATA_SAFELY_MACRO, SUCCESS) {
+  int* a = new int(10);
+  GET_DATA_SAFELY(a, "Input", "X", "dummy");
+  delete a;
+}
+
+TEST(GET_DATA_SAFELY_MACRO, FAIL) {
+  bool caught_exception = false;
+  try {
+    int* a = nullptr;
+    GET_DATA_SAFELY(a, "Input", "X", "dummy");
+  } catch (paddle::platform::EnforceNotMet& error) {
+    caught_exception = true;
+  }
+  EXPECT_TRUE(caught_exception);
+}
+
 TEST(OP_INOUT_CHECK_MACRO, SUCCESS) {
   OP_INOUT_CHECK(true, "Input", "X", "dummy");
 }
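For readers who want to try the pattern outside of Paddle, here is a minimal, self-contained sketch of the check GET_DATA_SAFELY performs. The SAFE_DEREF name, the std::runtime_error type, and the message text are hypothetical stand-ins assumed only for this example; the real macro throws paddle::platform::EnforceNotMet via __THROW_ERROR_INTERNAL__ and builds its message with paddle::platform::errors::NotFound.

#include <sstream>
#include <stdexcept>

// Hypothetical stand-in for GET_DATA_SAFELY: check the pointer first, then
// return an lvalue reference to the pointee; throw instead of dereferencing
// a null pointer. decltype(*(PTR)) is already T& for an lvalue dereference.
#define SAFE_DEREF(PTR, ROLE, NAME, OP_TYPE)                                \
  (([&]() -> decltype(*(PTR)) {                                             \
    auto* sketch_ptr = (PTR);                                               \
    if (sketch_ptr == nullptr) {                                            \
      std::ostringstream oss;                                               \
      oss << "Unable to get " << (ROLE) << " " << (NAME) << " in operator " \
          << (OP_TYPE) << ": pointer " #PTR " should not be null.";         \
      throw std::runtime_error(oss.str());                                  \
    }                                                                       \
    return *sketch_ptr;                                                     \
  })())

int main() {
  int value = 10;
  int* present = &value;
  int* missing = nullptr;

  int& ok = SAFE_DEREF(present, "Input", "X", "dummy");  // returns *present
  ok += 1;  // proves we hold a reference to `value`, as the real macro does

  try {
    SAFE_DEREF(missing, "Input", "X", "dummy");  // diagnosed, not dereferenced
  } catch (const std::runtime_error&) {
    return 0;  // expected path: a named, searchable error instead of a crash
  }
  return 1;
}

The role/name/op-type arguments are what turn a bare null-pointer crash into the structured "Unable to get ... in operator ..." diagnostic exercised by the new enforce_test cases above.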