Unverified commit 16315d3d, authored by Chen Weihang, committed by GitHub

Delete Ref & VectorRef and add GetDataSafely (#22997)

* delete invalid check interface Ref & VectorRef, test=develop

* fix vector ref delete error, test=develop

* try the new check interface, test=develop

* change all related code with new check macro, test=develop

* remove static assert, test=develop

* polish detail, test=develop

* skip coverage problem, test=develop

* add new check macro, test=develop
Parent 4c675a45
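The change repeated across the hunks below is mechanical: the printf-style null-check helper detail::Ref(ptr, msg, args...) is replaced by the GET_DATA_SAFELY(ptr, role, name, op_type) macro, which takes the tensor's role ("Input" or "Output"), its argument name, and the operator type, and builds a standardized error message from them. A minimal sketch of the before/after pattern, assuming a hypothetical operator "MyOp" (the op and variable names are placeholders, not part of this commit):

// Illustrative kernel fragment only; "MyOp" is a placeholder operator type.
#include "paddle/fluid/framework/op_registry.h"

template <typename DeviceContext, typename T>
class MyOpKernel : public paddle::framework::OpKernel<T> {
 public:
  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
    // Before this commit: the caller supplies a printf-style message.
    //   auto& x = detail::Ref(ctx.Input<paddle::framework::Tensor>("X"),
    //                         "Cannot get input tensor X, variable name = %s",
    //                         ctx.InputName("X"));
    // After this commit: role, argument name and op type are passed instead,
    // so null tensors are reported in one consistent format.
    auto& x = GET_DATA_SAFELY(ctx.Input<paddle::framework::Tensor>("X"),
                              "Input", "X", "MyOp");
    auto* out = ctx.Output<paddle::framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    // ... the actual computation over x into out is omitted here ...
  }
};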
...@@ -31,8 +31,8 @@ class CudnnActivationKernel ...@@ -31,8 +31,8 @@ class CudnnActivationKernel
ExtractActivationTensor(context, X, Out); ExtractActivationTensor(context, X, Out);
ActivationDescriptor act_desc; ActivationDescriptor act_desc;
TensorDescriptor x_desc, out_desc; TensorDescriptor x_desc, out_desc;
x_desc.set(detail::Ref(X)); x_desc.set(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivation"));
out_desc.set(detail::Ref(Out)); out_desc.set(GET_DATA_SAFELY(Out, "Output", "Out", "CudnnActivation"));
} }
}; };
......
...@@ -37,7 +37,7 @@ struct CudnnActivationFunctor { ...@@ -37,7 +37,7 @@ struct CudnnActivationFunctor {
act_desc.set(mode_, coef_); act_desc.set(mode_, coef_);
TensorDescriptor x_desc, out_desc; TensorDescriptor x_desc, out_desc;
x_desc.set(x); x_desc.set(x);
out_desc.set(detail::Ref(out)); out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation"));
PADDLE_ENFORCE(platform::dynload::cudnnActivationForward( PADDLE_ENFORCE(platform::dynload::cudnnActivationForward(
ctx_.cudnn_handle(), act_desc.desc(), ctx_.cudnn_handle(), act_desc.desc(),
platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(), platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
...@@ -63,7 +63,7 @@ struct CudnnActivationGradFunctor { ...@@ -63,7 +63,7 @@ struct CudnnActivationGradFunctor {
x_desc.set(x); x_desc.set(x);
out_desc.set(out); out_desc.set(out);
dout_desc.set(dout); dout_desc.set(dout);
dx_desc.set(detail::Ref(dx)); dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad"));
PADDLE_ENFORCE(platform::dynload::cudnnActivationBackward( PADDLE_ENFORCE(platform::dynload::cudnnActivationBackward(
ctx_.cudnn_handle(), act_desc.desc(), ctx_.cudnn_handle(), act_desc.desc(),
platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(), platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
...@@ -141,7 +141,7 @@ class CudnnActivationKernel ...@@ -141,7 +141,7 @@ class CudnnActivationKernel
Out->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<CUDADeviceContext>(); auto& dev_ctx = context.template device_context<CUDADeviceContext>();
Functor functor(dev_ctx); Functor functor(dev_ctx);
functor(detail::Ref(X), Out); functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivation"), Out);
} }
}; };
...@@ -161,7 +161,10 @@ class CudnnActivationGradKernel ...@@ -161,7 +161,10 @@ class CudnnActivationGradKernel
dX->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<CUDADeviceContext>(); auto& dev_ctx = context.template device_context<CUDADeviceContext>();
Functor functor(dev_ctx); Functor functor(dev_ctx);
functor(detail::Ref(X), detail::Ref(Out), detail::Ref(dOut), dX); functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivationGrad"),
GET_DATA_SAFELY(Out, "Input", "Out", "CudnnActivationGrad"),
GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "CudnnActivationGrad"),
dX);
} }
}; };
......
...@@ -26,7 +26,6 @@ limitations under the License. */ ...@@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
...@@ -156,8 +155,10 @@ class ActivationKernel ...@@ -156,8 +155,10 @@ class ActivationKernel
ExtractActivationTensor(context, &X, &Out); ExtractActivationTensor(context, &X, &Out);
Out->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X)); auto x = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(X, "Input", "X", "Activation"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "Activation"));
auto* place = auto* place =
context.template device_context<DeviceContext>().eigen_device(); context.template device_context<DeviceContext>().eigen_device();
Functor functor; Functor functor;
...@@ -182,10 +183,14 @@ class ActivationGradKernel ...@@ -182,10 +183,14 @@ class ActivationGradKernel
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut, ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
&dX); &dX);
dX->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut)); auto dout = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad"));
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX)); auto out = framework::EigenVector<T>::Flatten(
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X)); GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad"));
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad"));
auto* place = auto* place =
context.template device_context<DeviceContext>().eigen_device(); context.template device_context<DeviceContext>().eigen_device();
Functor functor; Functor functor;
...@@ -1285,10 +1290,13 @@ struct ReluGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -1285,10 +1290,13 @@ struct ReluGradGradFunctor : public BaseActivationFunctor<T> {
framework::Tensor* ddOut, framework::Tensor* dOut, framework::Tensor* ddOut, framework::Tensor* dOut,
framework::Tensor* dX) const { framework::Tensor* dX) const {
auto* d = dev.eigen_device(); auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX)); auto ddx = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad"));
if (ddOut) { if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut)); auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad"));
ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>(); ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>();
} }
} }
...@@ -1308,9 +1316,12 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -1308,9 +1316,12 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
framework::Tensor* dX) const { framework::Tensor* dX) const {
if (ddOut) { if (ddOut) {
auto* d = dev.eigen_device(); auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX)); auto ddx = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad"));
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut)); auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "LeakyReluGradGrad"));
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DOut", "LeakyReluGradGrad"));
ddout.device(*d) = ddx * ddout.device(*d) = ddx *
((out > static_cast<T>(0)).template cast<T>() + ((out > static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) * static_cast<T>(alpha) *
...@@ -1332,18 +1343,23 @@ struct ELUGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -1332,18 +1343,23 @@ struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
const framework::Tensor* ddX, framework::Tensor* ddOut, const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const { const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device(); auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX)); auto ddx = framework::EigenVector<T>::Flatten(
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X)); GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad"));
if (dX) { if (dX) {
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX)); auto dx = framework::EigenVector<T>::Flatten(
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut)); GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad"));
dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() * dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
(x < static_cast<T>(0)).template cast<T>(); (x < static_cast<T>(0)).template cast<T>();
} }
if (ddOut) { if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut)); auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad"));
ddout.device(*d) = ddx * ddout.device(*d) = ddx *
((x > static_cast<T>(0)).template cast<T>() + ((x > static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) * x.exp() * static_cast<T>(alpha) * x.exp() *
...@@ -1361,17 +1377,22 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -1361,17 +1377,22 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
const framework::Tensor* ddX, framework::Tensor* ddOut, const framework::Tensor* ddX, framework::Tensor* ddOut,
framework::Tensor* dOut, const framework::Tensor* dX) const { framework::Tensor* dOut, const framework::Tensor* dX) const {
auto* d = dev.eigen_device(); auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX)); auto ddx = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
// calculate dy first, so ddy can inplace ddx // calculate dy first, so ddy can inplace ddx
if (dOut) { if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX)); auto dx = framework::EigenVector<T>::Flatten(
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut)); GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out; dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
} }
if (ddOut) { if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut)); auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(0.5) / out; ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
} }
} }
...@@ -1385,17 +1406,22 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -1385,17 +1406,22 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
const framework::Tensor* ddX, framework::Tensor* ddOut, const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const { const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device(); auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX)); auto ddx = framework::EigenVector<T>::Flatten(
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X)); GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad"));
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx // square GradGrad: ddy=2x*ddx, dx=2dy*ddx
// calculate dx first, so ddy can inplace ddx // calculate dx first, so ddy can inplace ddx
if (dX) { if (dX) {
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX)); auto dx = framework::EigenVector<T>::Flatten(
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut)); GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad"));
dx.device(*d) = ddx * static_cast<T>(2) * dout; dx.device(*d) = ddx * static_cast<T>(2) * dout;
} }
if (ddOut) { if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut)); auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(2) * x; ddout.device(*d) = ddx * static_cast<T>(2) * x;
} }
} }
...@@ -1557,8 +1583,10 @@ class PowKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> { ...@@ -1557,8 +1583,10 @@ class PowKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
ExtractActivationTensor(context, &X, &Out); ExtractActivationTensor(context, &X, &Out);
Out->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X)); auto x = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(X, "Input", "X", "Pow"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "Pow"));
auto* place = auto* place =
context.template device_context<DeviceContext>().eigen_device(); context.template device_context<DeviceContext>().eigen_device();
Functor functor; Functor functor;
...@@ -1602,10 +1630,14 @@ class PowGradKernel ...@@ -1602,10 +1630,14 @@ class PowGradKernel
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut, ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
&dX); &dX);
dX->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut)); auto dout = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "PowGrad"));
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX)); auto out = framework::EigenVector<T>::Flatten(
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X)); GET_DATA_SAFELY(Out, "Input", "Out", "PowGrad"));
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "X@GRAD", "PowGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "PowGrad"));
auto* place = auto* place =
context.template device_context<DeviceContext>().eigen_device(); context.template device_context<DeviceContext>().eigen_device();
Functor functor; Functor functor;
......
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
...@@ -56,9 +55,8 @@ class GetPlacesOp : public framework::OperatorBase { ...@@ -56,9 +55,8 @@ class GetPlacesOp : public framework::OperatorBase {
is_gpu ? "GPU" : "CPU"); is_gpu ? "GPU" : "CPU");
auto out_var_name = Output("Out"); auto out_var_name = Output("Out");
auto &places = auto &places = *(GET_DATA_SAFELY(scope.FindVar(out_var_name), "Output",
*(detail::Ref(scope.FindVar(out_var_name), "Out", "GetPlaces")
"Output variable %s cannot be found", out_var_name)
.GetMutable<platform::PlaceList>()); .GetMutable<platform::PlaceList>());
places.reserve(device_count); places.reserve(device_count);
if (is_gpu) { if (is_gpu) {
......
...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/array_operator.h" #include "paddle/fluid/operators/array_operator.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
namespace paddle { namespace paddle {
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -198,23 +197,18 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -198,23 +197,18 @@ class WhileGradOp : public framework::OperatorBase {
continue; continue;
} }
auto &og_outside = auto &og_outside = *scope.FindVar(outside_og_name);
detail::Ref(scope.FindVar(outside_og_name), auto &og_inside = *cur_scope.Var(inside_og_name);
"Cannot find Outside Gradient %s", outside_og_name);
auto &og_inside =
detail::Ref(cur_scope.Var(inside_og_name),
"Cannot find inside gradient %s", inside_og_name);
if (og_outside.IsType<framework::LoDTensor>()) { if (og_outside.IsType<framework::LoDTensor>()) {
auto &outside_tensor = og_outside.Get<framework::LoDTensor>(); auto &outside_tensor = og_outside.Get<framework::LoDTensor>();
auto &inside_tensor = auto &inside_tensor = *og_inside.GetMutable<framework::LoDTensor>();
detail::Ref(og_inside.GetMutable<framework::LoDTensor>());
inside_tensor.set_lod(outside_tensor.lod()); inside_tensor.set_lod(outside_tensor.lod());
inside_tensor.ShareDataWith(outside_tensor); inside_tensor.ShareDataWith(outside_tensor);
} else if (og_outside.IsType<framework::LoDTensorArray>()) { } else if (og_outside.IsType<framework::LoDTensorArray>()) {
auto outside_array = auto outside_array =
og_outside.GetMutable<framework::LoDTensorArray>(); og_outside.GetMutable<framework::LoDTensorArray>();
auto &inside_array = auto &inside_array =
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>()); *og_inside.GetMutable<framework::LoDTensorArray>();
inside_array.clear(); inside_array.clear();
inside_array.resize(outside_array->size()); inside_array.resize(outside_array->size());
VLOG(8) << outside_og_name << " size = " << outside_array->size(); VLOG(8) << outside_og_name << " size = " << outside_array->size();
......
...@@ -20,7 +20,6 @@ limitations under the License. */ ...@@ -20,7 +20,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/im2col.h"
...@@ -674,9 +673,8 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> { ...@@ -674,9 +673,8 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
Tensor* ddY = ctx.Output<Tensor>("DDOutput"); Tensor* ddY = ctx.Output<Tensor>("DDOutput");
Tensor* dW = ctx.Output<Tensor>("DFilter"); Tensor* dW = ctx.Output<Tensor>("DFilter");
Tensor* dX = ctx.Output<Tensor>("DInput"); Tensor* dX = ctx.Output<Tensor>("DInput");
Tensor W = detail::Ref(ctx.Input<Tensor>("Filter"), Tensor W = GET_DATA_SAFELY(ctx.Input<Tensor>("Filter"), "Input", "Filter",
"Cannot find input Filter(%s) in scope)", "GemmConvDoubleGrad");
ctx.InputNames("Filter")[0]);
if (!ddY && !dW && !dX) return; if (!ddY && !dW && !dX) return;
const int groups = ctx.Attr<int>("groups"); const int groups = ctx.Attr<int>("groups");
......
...@@ -18,7 +18,6 @@ limitations under the License. */ ...@@ -18,7 +18,6 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -29,13 +28,11 @@ class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> { ...@@ -29,13 +28,11 @@ class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
using T = typename Functor::ELEMENT_TYPE; using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto& X = detail::Ref(context.Input<framework::Tensor>("X"), auto& X = GET_DATA_SAFELY(context.Input<framework::Tensor>("X"), "Input",
"Cannot get input tensor X, variable name = %s", "X", "Cum");
context.InputName("X"));
auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"), auto& Out = GET_DATA_SAFELY(context.Output<framework::Tensor>("Out"),
"Cannot get output tensor Out, variable name = %s", "Output", "Out", "Cum");
context.OutputName("Out"));
int axis = context.Attr<int>("axis"); int axis = context.Attr<int>("axis");
bool exclusive = context.Attr<bool>("exclusive"); bool exclusive = context.Attr<bool>("exclusive");
bool reverse = context.Attr<bool>("reverse"); bool reverse = context.Attr<bool>("reverse");
...@@ -46,7 +43,7 @@ class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> { ...@@ -46,7 +43,7 @@ class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
axis, x_dims.size(), axis, x_dims.size(),
"axis should be less than the dimensiotn of the input tensor"); "axis should be less than the dimensiotn of the input tensor");
Out.mutable_data<T>(context.GetPlace()); Out.template mutable_data<T>(context.GetPlace());
int pre = 1; int pre = 1;
int post = 1; int post = 1;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace detail {
/**
 * Get a reference from a pointer, with a null check. The error message is a
 * printf-style format string whose arguments are passed via `args`.
*/
template <typename T, typename... ARGS>
inline T& Ref(T* ptr, ARGS&&... args) {
PADDLE_ENFORCE_NOT_NULL(ptr, ::paddle::string::Sprintf(args...));
return *ptr;
}
template <typename T, typename... ARGS>
inline std::vector<std::reference_wrapper<T>> VectorRef(
const std::vector<T*>& vec, ARGS&&... args) {
std::vector<std::reference_wrapper<T>> result;
result.reserve(vec.size());
for (auto* ptr : vec) {
result.emplace_back(Ref(ptr, args...));
}
return result;
}
} // namespace detail
} // namespace operators
} // namespace paddle
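For context, a minimal usage sketch of the two helpers deleted above (not taken from the diff; "SomeOp" and the tensor names are illustrative): Ref dereferences a single pointer or raises an enforce error with the formatted message, and VectorRef applies the same check to every pointer in a vector.

// Illustrative only: how detail::Ref / detail::VectorRef were typically called
// before this commit removed them. Names are placeholders.
auto& x = detail::Ref(ctx.Input<framework::Tensor>("X"),
                      "Cannot get input tensor X of %s", "SomeOp");

std::vector<const framework::Tensor*> ins =
    ctx.MultiInput<framework::Tensor>("Inputs");
// Returns std::vector<std::reference_wrapper<const framework::Tensor>>.
auto in_refs = detail::VectorRef(ins, "Inputs of %s must not be null", "SomeOp");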
...@@ -20,7 +20,6 @@ limitations under the License.*/ ...@@ -20,7 +20,6 @@ limitations under the License.*/
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
......
...@@ -20,7 +20,6 @@ limitations under the License. */ ...@@ -20,7 +20,6 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
......
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -293,12 +292,10 @@ class GenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -293,12 +292,10 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
auto *scores = context.Input<Tensor>("Scores"); auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas"); auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_info = context.Input<Tensor>("ImInfo"); auto *im_info = context.Input<Tensor>("ImInfo");
auto anchors = detail::Ref(context.Input<Tensor>("Anchors"), auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"), "Input",
"Cannot find input Anchors(%s) in scope", "Anchors", "GenerateProposals");
context.InputNames("Anchors")[0]); auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
auto variances = detail::Ref(context.Input<Tensor>("Variances"), "Input", "Variances", "GenerateProposals");
"Cannot find input Variances(%s) in scope",
context.InputNames("Variances")[0]);
auto *rpn_rois = context.Output<LoDTensor>("RpnRois"); auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs"); auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
......
...@@ -20,7 +20,6 @@ limitations under the License. */ ...@@ -20,7 +20,6 @@ limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.cu.h" #include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
...@@ -367,12 +366,10 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> { ...@@ -367,12 +366,10 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
auto *scores = context.Input<Tensor>("Scores"); auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas"); auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_info = context.Input<Tensor>("ImInfo"); auto *im_info = context.Input<Tensor>("ImInfo");
auto anchors = detail::Ref(context.Input<Tensor>("Anchors"), auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"), "Input",
"Cannot find input Anchors(%s) in scope", "Anchors", "GenerateProposals");
context.InputNames("Anchors")[0]); auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
auto variances = detail::Ref(context.Input<Tensor>("Variances"), "Input", "Variances", "GenerateProposals");
"Cannot find input Variances(%s) in scope",
context.InputNames("Variances")[0]);
auto *rpn_rois = context.Output<LoDTensor>("RpnRois"); auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs"); auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");
......
...@@ -19,7 +19,6 @@ limitations under the License. */ ...@@ -19,7 +19,6 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -44,10 +43,8 @@ template <typename T> ...@@ -44,10 +43,8 @@ template <typename T>
class FillKernel : public framework::OpKernel<T> { class FillKernel : public framework::OpKernel<T> {
public: public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override { void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto &out = auto &out = GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Out"),
detail::Ref(ctx.Output<framework::LoDTensor>("Out"), "Output", "Out", "Fill");
"Cannot get output lod tensor Out, variable name = %s",
ctx.OutputName("Out"));
out.Resize(framework::make_ddim(ctx.Attr<std::vector<int>>("shape"))); out.Resize(framework::make_ddim(ctx.Attr<std::vector<int>>("shape")));
auto dtype = auto dtype =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")); static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
......
...@@ -18,7 +18,6 @@ limitations under the License. */ ...@@ -18,7 +18,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/compound_functors.h" #include "paddle/fluid/operators/math/compound_functors.h"
#include "paddle/fluid/operators/math/functors.h" #include "paddle/fluid/operators/math/functors.h"
...@@ -383,12 +382,10 @@ template <typename DeviceContext, typename T> ...@@ -383,12 +382,10 @@ template <typename DeviceContext, typename T>
class FusedElemwiseActivationKernel : public framework::OpKernel<T> { class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto &in_x = detail::Ref(ctx.Input<framework::Tensor>("X"), auto &in_x = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("X"), "Input",
"Cannot get input tensor %s, variable name = %s", "X", "FusedElemwiseActivation");
"X", ctx.InputName("X")); auto &in_y = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("Y"), "Input",
auto &in_y = detail::Ref(ctx.Input<framework::Tensor>("Y"), "Y", "FusedElemwiseActivation");
"Cannot get input tensor %s, variable name = %s",
"Y", ctx.InputName("Y"));
PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty"); PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty");
auto output = ctx.Output<framework::Tensor>("Out"); auto output = ctx.Output<framework::Tensor>("Out");
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
namespace paddle { namespace paddle {
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h" #include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
namespace paddle { namespace paddle {
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h" #include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
...@@ -142,14 +141,13 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> { ...@@ -142,14 +141,13 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
auto *input = context.Input<framework::Tensor>("Input"); auto *input = context.Input<framework::Tensor>("Input");
auto *w = context.Input<framework::Tensor>("W"); auto *w = context.Input<framework::Tensor>("W");
auto *bias = context.Input<framework::Tensor>("Bias"); auto *bias = context.Input<framework::Tensor>("Bias");
auto &bias_qk = GET_DATA_SAFELY(context.Input<framework::Tensor>("BiasQK"),
auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"), "Input", "BiasQK", "MultiHeadMatMulV2");
"Cannot find QK");
auto *input_d = input->data<T>(); auto *input_d = input->data<T>();
auto *w_d = w->data<T>(); auto *w_d = w->data<T>();
auto *bias_d = bias->data<T>(); auto *bias_d = bias->data<T>();
auto *bias_qk_d = bias_qk.data<T>(); auto *bias_qk_d = bias_qk.template data<T>();
T scale = static_cast<T>(context.Attr<float>("alpha")); T scale = static_cast<T>(context.Attr<float>("alpha"));
int head_number = context.Attr<int>("head_number"); int head_number = context.Attr<int>("head_number");
......
...@@ -24,7 +24,6 @@ limitations under the License. */ ...@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h" #include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
...@@ -40,8 +39,9 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -40,8 +39,9 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform; using platform::Transform;
using framework::LoDTensor;
static std::vector<int64_t> PathToRows(const framework::LoDTensor& path) { static std::vector<int64_t> PathToRows(const LoDTensor& path) {
std::set<int64_t> rows; std::set<int64_t> rows;
const int64_t* paths = path.data<int64_t>(); const int64_t* paths = path.data<int64_t>();
for (int64_t i = 0; i < path.numel(); ++i) { for (int64_t i = 0; i < path.numel(); ++i) {
...@@ -57,14 +57,17 @@ template <typename DeviceContext, typename T> ...@@ -57,14 +57,17 @@ template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X")); auto& in = GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X",
auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W")); "HierarchicalSigmoid");
auto* path = ctx.Input<framework::LoDTensor>("PathTable"); auto& w = GET_DATA_SAFELY(ctx.Input<LoDTensor>("W"), "Input", "W",
auto* code = ctx.Input<framework::LoDTensor>("PathCode"); "HierarchicalSigmoid");
auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label")); auto* path = ctx.Input<LoDTensor>("PathTable");
auto* bias = ctx.Input<framework::LoDTensor>("Bias"); auto* code = ctx.Input<LoDTensor>("PathCode");
auto* out = ctx.Output<framework::LoDTensor>("Out"); auto& label = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Label"), "Input",
auto* pre_out = ctx.Output<framework::LoDTensor>("PreOut"); "Label", "HierarchicalSigmoid");
auto* bias = ctx.Input<LoDTensor>("Bias");
auto* out = ctx.Output<LoDTensor>("Out");
auto* pre_out = ctx.Output<LoDTensor>("PreOut");
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes")); size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
// for remote prefetch // for remote prefetch
...@@ -75,7 +78,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { ...@@ -75,7 +78,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
int64_t code_length = int64_t code_length =
path ? path->dims()[1] : math::FindLastSet(num_classes - 1); path ? path->dims()[1] : math::FindLastSet(num_classes - 1);
int64_t batch_size = in.dims()[0]; int64_t batch_size = in.dims()[0];
framework::LoDTensor sum; LoDTensor sum;
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* pre_out_data = pre_out->mutable_data<T>( auto* pre_out_data = pre_out->mutable_data<T>(
framework::make_ddim({batch_size, code_length}), ctx.GetPlace()); framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
...@@ -89,11 +92,11 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { ...@@ -89,11 +92,11 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code; std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
if (!is_custom) { if (!is_custom) {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes, bit_code.reset(new math::MatrixBitCodeFunctor<T>(
label.data<int64_t>())); num_classes, label.template data<int64_t>()));
} else { } else {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(*path, *code, bit_code.reset(new math::MatrixBitCodeFunctor<T>(
label.data<int64_t>())); *path, *code, label.template data<int64_t>()));
} }
std::vector<int64_t> sum_dims({batch_size, 1UL}); std::vector<int64_t> sum_dims({batch_size, 1UL});
...@@ -126,20 +129,24 @@ template <typename DeviceContext, typename T> ...@@ -126,20 +129,24 @@ template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X")); auto& in = GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X",
auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W")); "HierarchicalSigmoidGrad");
auto* path = ctx.Input<framework::LoDTensor>("PathTable"); auto& w = GET_DATA_SAFELY(ctx.Input<LoDTensor>("W"), "Input", "W",
auto* code = ctx.Input<framework::LoDTensor>("PathCode"); "HierarchicalSigmoidGrad");
auto* in_grad = auto* path = ctx.Input<LoDTensor>("PathTable");
ctx.Output<framework::LoDTensor>(framework::GradVarName("X")); auto* code = ctx.Input<LoDTensor>("PathCode");
auto* in_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
bool is_sparse = ctx.Attr<bool>("is_sparse"); bool is_sparse = ctx.Attr<bool>("is_sparse");
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero; math::SetConstant<DeviceContext, T> zero;
auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label")); auto& label = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Label"), "Input",
auto& pre_out = detail::Ref(ctx.Input<framework::LoDTensor>("PreOut")); "Label", "HierarchicalSigmoidGrad");
auto& out_grad = detail::Ref( auto& pre_out = GET_DATA_SAFELY(ctx.Input<LoDTensor>("PreOut"), "Input",
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))); "PreOut", "HierarchicalSigmoidGrad");
framework::LoDTensor pre_out_grad; auto& out_grad = GET_DATA_SAFELY(
ctx.Input<LoDTensor>(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "HierarchicalSigmoidGrad");
LoDTensor pre_out_grad;
pre_out_grad.mutable_data<T>(pre_out.dims(), ctx.GetPlace()); pre_out_grad.mutable_data<T>(pre_out.dims(), ctx.GetPlace());
in_grad->mutable_data<T>(ctx.GetPlace()); in_grad->mutable_data<T>(ctx.GetPlace());
...@@ -154,11 +161,11 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { ...@@ -154,11 +161,11 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code; std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
if (!is_custom) { if (!is_custom) {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes, bit_code.reset(new math::MatrixBitCodeFunctor<T>(
label.data<int64_t>())); num_classes, label.template data<int64_t>()));
} else { } else {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(*path, *code, bit_code.reset(new math::MatrixBitCodeFunctor<T>(
label.data<int64_t>())); *path, *code, label.template data<int64_t>()));
} }
// softrelu derivative // softrelu derivative
...@@ -166,7 +173,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { ...@@ -166,7 +173,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
auto* pre_out_grad_data = pre_out_grad.data<T>(); auto* pre_out_grad_data = pre_out_grad.data<T>();
auto* pre_out_data = pre_out.data<T>(); auto* pre_out_data = pre_out.template data<T>();
auto n = pre_out.numel(); auto n = pre_out.numel();
blas.VEXP(n, pre_out_data, pre_out_grad_data); blas.VEXP(n, pre_out_data, pre_out_grad_data);
blas.VINV(n, pre_out_grad_data, pre_out_grad_data); blas.VINV(n, pre_out_grad_data, pre_out_grad_data);
...@@ -174,7 +181,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { ...@@ -174,7 +181,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i]; pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i];
} }
bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b) bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b)
auto* out_grad_data = out_grad.data<T>(); auto* out_grad_data = out_grad.template data<T>();
int64_t dim0 = pre_out_grad.dims()[0]; int64_t dim0 = pre_out_grad.dims()[0];
int64_t dim1 = pre_out_grad.dims()[1]; int64_t dim1 = pre_out_grad.dims()[1];
...@@ -184,16 +191,14 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { ...@@ -184,16 +191,14 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
} }
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward. // be consistent with the clipping in forward.
auto* bias_grad = auto* bias_grad = ctx.Output<LoDTensor>(framework::GradVarName("Bias"));
ctx.Output<framework::LoDTensor>(framework::GradVarName("Bias"));
if (bias_grad) { if (bias_grad) {
bias_grad->mutable_data<T>(ctx.GetPlace()); bias_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, bias_grad, static_cast<T>(0.0)); zero(dev_ctx, bias_grad, static_cast<T>(0.0));
bit_code->AddGrad(pre_out_grad, bias_grad); bit_code->AddGrad(pre_out_grad, bias_grad);
} }
if (!is_sparse) { if (!is_sparse) {
auto* w_grad = auto* w_grad = ctx.Output<LoDTensor>(framework::GradVarName("W"));
ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
w_grad->mutable_data<T>(ctx.GetPlace()); w_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, w_grad, static_cast<T>(0.0)); zero(dev_ctx, w_grad, static_cast<T>(0.0));
bit_code->MulGradWeight(pre_out_grad, w_grad, in); bit_code->MulGradWeight(pre_out_grad, w_grad, in);
......
...@@ -16,7 +16,6 @@ limitations under the License. */ ...@@ -16,7 +16,6 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
...@@ -95,13 +94,15 @@ class LoDTensorToArrayOp : public framework::OperatorBase { ...@@ -95,13 +94,15 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input %s", auto &x = GET_DATA_SAFELY(scope.FindVar(Input("X")), "Input", "X",
Input("X")) "LoDTensorToArray")
.Get<framework::LoDTensor>(); .Get<framework::LoDTensor>();
auto &rank_table = detail::Ref(scope.FindVar(Input("RankTable"))) auto &rank_table = GET_DATA_SAFELY(scope.FindVar(Input("RankTable")),
"Input", "RankTable", "LoDTensorToArray")
.Get<framework::LoDRankTable>(); .Get<framework::LoDRankTable>();
auto &out = *detail::Ref(scope.FindVar(Output("Out"))) auto &out = *(GET_DATA_SAFELY(scope.FindVar(Output("Out")), "Output", "Out",
.GetMutable<framework::LoDTensorArray>(); "LoDTensorToArray")
.GetMutable<framework::LoDTensorArray>());
auto &items = rank_table.items(); auto &items = rank_table.items();
auto max_seq_len = items[0].length; auto max_seq_len = items[0].length;
auto rank_level = rank_table.level(); auto rank_level = rank_table.level();
......
...@@ -16,7 +16,6 @@ limitations under the License. */ ...@@ -16,7 +16,6 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
namespace paddle { namespace paddle {
...@@ -58,10 +57,10 @@ template <typename DeviceContext, typename T> ...@@ -58,10 +57,10 @@ template <typename DeviceContext, typename T>
class MatMulKernel : public framework::OpKernel<T> { class MatMulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto &x = auto &x = GET_DATA_SAFELY(context.Input<framework::Tensor>("X"), "Input",
detail::Ref(context.Input<framework::Tensor>("X"), "Cannot find X"); "X", "MatMul");
auto &y = auto &y = GET_DATA_SAFELY(context.Input<framework::Tensor>("Y"), "Input",
detail::Ref(context.Input<framework::Tensor>("Y"), "Cannot find Y"); "Y", "MatMul");
auto *out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
......
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
......
...@@ -128,7 +128,6 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> { ...@@ -128,7 +128,6 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
framework::ToTypeName(param_var->Type()))); framework::ToTypeName(param_var->Type())));
using paddle::framework::LoDTensor; using paddle::framework::LoDTensor;
using paddle::operators::detail::Ref;
int64_t min_row_size_to_use_multithread = int64_t min_row_size_to_use_multithread =
ctx.Attr<int64_t>("min_row_size_to_use_multithread"); ctx.Attr<int64_t>("min_row_size_to_use_multithread");
......
...@@ -20,7 +20,6 @@ limitations under the License. */ ...@@ -20,7 +20,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
...@@ -384,7 +383,6 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -384,7 +383,6 @@ class AdamOpKernel : public framework::OpKernel<T> {
framework::ToTypeName(param_var->Type())); framework::ToTypeName(param_var->Type()));
using paddle::framework::LoDTensor; using paddle::framework::LoDTensor;
using paddle::operators::detail::Ref;
int64_t min_row_size_to_use_multithread = int64_t min_row_size_to_use_multithread =
ctx.Attr<int64_t>("min_row_size_to_use_multithread"); ctx.Attr<int64_t>("min_row_size_to_use_multithread");
......
...@@ -17,7 +17,6 @@ limitations under the License. */ ...@@ -17,7 +17,6 @@ limitations under the License. */
#include <Eigen/Dense> #include <Eigen/Dense>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
...@@ -185,30 +184,32 @@ class LambOpKernel : public framework::OpKernel<T> { ...@@ -185,30 +184,32 @@ class LambOpKernel : public framework::OpKernel<T> {
framework::ToTypeName(param_var->Type())); framework::ToTypeName(param_var->Type()));
using paddle::framework::LoDTensor; using paddle::framework::LoDTensor;
using paddle::operators::detail::Ref;
T weight_decay = static_cast<T>(ctx.Attr<float>("weight_decay")); T weight_decay = static_cast<T>(ctx.Attr<float>("weight_decay"));
T beta1 = static_cast<T>(ctx.Attr<float>("beta1")); T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
T beta2 = static_cast<T>(ctx.Attr<float>("beta2")); T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon")); T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto& param = Ref(ctx.Input<LoDTensor>("Param"), "Must set Param."); auto& param = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Param"), "Input",
"Param", "Lamb");
auto* grad_var = ctx.InputVar("Grad"); auto* grad_var = ctx.InputVar("Grad");
auto& mom1 = Ref(ctx.Input<LoDTensor>("Moment1"), "Must set Moment1."); auto& mom1 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment1"), "Input",
auto& mom2 = Ref(ctx.Input<LoDTensor>("Moment2"), "Must set Moment2."); "Moment1", "Lamb");
auto& lr = auto& mom2 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment2"), "Input",
Ref(ctx.Input<LoDTensor>("LearningRate"), "Must set LearningRate."); "Moment2", "Lamb");
auto& lr = GET_DATA_SAFELY(ctx.Input<LoDTensor>("LearningRate"), "Input",
auto& beta1_pow = "LearningRate", "Lamb");
Ref(ctx.Input<LoDTensor>("Beta1Pow"), "Must set Beta1Pow.");
auto& beta2_pow = auto& beta1_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta1Pow"), "Input",
Ref(ctx.Input<LoDTensor>("Beta2Pow"), "Must set Beta2Pow."); "Beta1Pow", "Lamb");
auto& beta2_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta2Pow"), "Input",
auto& param_out = "Beta2Pow", "Lamb");
Ref(ctx.Output<LoDTensor>("ParamOut"), "Must set ParamOut.");
auto& mom1_out = auto& param_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("ParamOut"),
Ref(ctx.Output<LoDTensor>("Moment1Out"), "Must set Moment1Out."); "Output", "ParamOut", "Lamb");
auto& mom2_out = auto& mom1_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment1Out"),
Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment1Out."); "Output", "Moment1Out", "Lamb");
auto& mom2_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment2Out"),
"Output", "Moment2Out", "Lamb");
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, param.numel()); platform::ForRange<DeviceContext> for_range(dev_ctx, param.numel());
...@@ -217,7 +218,7 @@ class LambOpKernel : public framework::OpKernel<T> { ...@@ -217,7 +218,7 @@ class LambOpKernel : public framework::OpKernel<T> {
// Update moments // Update moments
if (grad_var->IsType<framework::LoDTensor>()) { if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad."); auto& grad = *ctx.Input<LoDTensor>("Grad");
LambMomentUpdateFunctor<T> moment_update_functor( LambMomentUpdateFunctor<T> moment_update_functor(
weight_decay, beta1, beta2, epsilon, beta1_pow.template data<T>(), weight_decay, beta1, beta2, epsilon, beta1_pow.template data<T>(),
...@@ -229,8 +230,8 @@ class LambOpKernel : public framework::OpKernel<T> { ...@@ -229,8 +230,8 @@ class LambOpKernel : public framework::OpKernel<T> {
trust_ratio_div.template data<T>()); trust_ratio_div.template data<T>());
for_range(moment_update_functor); for_range(moment_update_functor);
} else if (grad_var->IsType<framework::SelectedRows>()) { } else if (grad_var->IsType<framework::SelectedRows>()) {
auto& grad = Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad."); auto& grad = GET_DATA_SAFELY(ctx.Input<framework::SelectedRows>("Grad"), "Input", "Grad", "Lamb");
if (grad.rows().size() == 0) { if (grad.rows().size() == 0) {
VLOG(3) << "grad row size is 0!!"; VLOG(3) << "grad row size is 0!!";
return; return;
......
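The hunk above shows the pattern this commit applies throughout: the free-form message passed to detail::Ref is replaced by GET_DATA_SAFELY, which instead takes the role, the argument name, and the op type and builds a uniform error message from them. A minimal before/after sketch, with identifiers taken from the Lamb hunk above:

  // Before: ad-hoc message string, removed in this commit.
  auto& param = Ref(ctx.Input<LoDTensor>("Param"), "Must set Param.");
  // After: role ("Input"/"Output"), argument name, and op type drive the message.
  auto& param = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Param"), "Input", "Param", "Lamb");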
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -152,10 +151,11 @@ class RandomCropKernel : public framework::OpKernel<T> { ...@@ -152,10 +151,11 @@ class RandomCropKernel : public framework::OpKernel<T> {
public: public:
virtual void Compute(const framework::ExecutionContext& ctx) const { virtual void Compute(const framework::ExecutionContext& ctx) const {
int64_t seed = 0; int64_t seed = 0;
auto& seed_tensor = detail::Ref(ctx.Input<framework::LoDTensor>("Seed")); auto& seed_tensor = GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("Seed"), "Input", "Seed", "RandomCrop");
if (seed_tensor.IsInitialized()) { if (seed_tensor.IsInitialized()) {
if (platform::is_cpu_place(seed_tensor.place())) { if (platform::is_cpu_place(seed_tensor.place())) {
seed = *seed_tensor.data<int64_t>(); seed = *seed_tensor.template data<int64_t>();
} else { } else {
LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify " LOG(WARNING) << "It is slow to place seed in GPU memory. Please verify "
"your program"; "your program";
...@@ -169,13 +169,15 @@ class RandomCropKernel : public framework::OpKernel<T> { ...@@ -169,13 +169,15 @@ class RandomCropKernel : public framework::OpKernel<T> {
seed = ctx.Attr<int>("startup_seed"); seed = ctx.Attr<int>("startup_seed");
} }
auto shape = ctx.Attr<std::vector<int>>("shape"); auto shape = ctx.Attr<std::vector<int>>("shape");
auto& x = detail::Ref(ctx.Input<framework::LoDTensor>("X")); auto& x = GET_DATA_SAFELY(ctx.Input<framework::LoDTensor>("X"), "Input", "X", "RandomCrop");
auto& out = detail::Ref(ctx.Output<framework::LoDTensor>("Out")); auto& out = GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Out"), "Output", "Out", "RandomCrop");
int num_batchsize_dims = x.dims().size() - shape.size(); int num_batchsize_dims = x.dims().size() - shape.size();
RandomCropFunctor<DeviceContext, T> functor( RandomCropFunctor<DeviceContext, T> functor(
x.data<T>(), out.mutable_data<T>(ctx.GetPlace()), x.dims(), out.dims(), num_batchsize_dims, seed); x.template data<T>(), out.template mutable_data<T>(ctx.GetPlace()), x.dims(), out.dims(), num_batchsize_dims, seed);
platform::ForRange<DeviceContext> for_range( platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), ctx.template device_context<DeviceContext>(),
functor.prod_batchsize_dims_); functor.prod_batchsize_dims_);
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h" #include "paddle/fluid/operators/reader/reader_op_registry.h"
namespace paddle { namespace paddle {
...@@ -171,8 +170,11 @@ void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) { ...@@ -171,8 +170,11 @@ void CustomReader::ReadNextImpl(std::vector<framework::LoDTensor>* out) {
// 3. Copy LoDTensors from sink variables to out. // 3. Copy LoDTensors from sink variables to out.
out->resize(sink_var_names_.size()); out->resize(sink_var_names_.size());
for (size_t i = 0; i < sink_var_names_.size(); ++i) { for (size_t i = 0; i < sink_var_names_.size(); ++i) {
const auto& tensor = detail::Ref(exe_scope->FindVar(sink_var_names_[i])).Get<framework::LoDTensor>(); auto* var = exe_scope->FindVar(sink_var_names_[i]);
PADDLE_ENFORCE_NOT_NULL(var, platform::errors::NotFound("The variable %s is not in current scope.", sink_var_names_[i]));
const auto& tensor = var->Get<framework::LoDTensor>();
framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]); framework::TensorCopySync(tensor, platform::CPUPlace(), &(*out)[i]);
} }
scope_.DeleteScope(exe_scope); scope_.DeleteScope(exe_scope);
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
...@@ -96,8 +95,8 @@ class ReadOp : public framework::OperatorBase { ...@@ -96,8 +95,8 @@ class ReadOp : public framework::OperatorBase {
const platform::Place& dev_place) const override { const platform::Place& dev_place) const override {
VLOG(3) << "read op in"; VLOG(3) << "read op in";
framework::ReaderHolder* reader = framework::ReaderHolder* reader =
detail::Ref(scope.FindVar(Input("Reader")), GET_DATA_SAFELY(scope.FindVar(Input("Reader")), "Input", "Reader",
"Cannot find reader variable %s", Input("Reader")) "Read")
.GetMutable<framework::ReaderHolder>(); .GetMutable<framework::ReaderHolder>();
std::vector<std::string> out_arg_names = Outputs("Out"); std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins; std::vector<framework::LoDTensor> ins;
......
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace paddle { namespace paddle {
...@@ -78,18 +77,16 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase { ...@@ -78,18 +77,16 @@ class ReorderLoDTensorByRankTableBase : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input lod tensor variable %s", Input("X")).Get<framework::LoDTensor>(); auto &x = GET_DATA_SAFELY(scope.FindVar(Input("X")), "Input", "X", "ReorderLoDTensorByRankTable").Get<framework::LoDTensor>();
auto &rank_table = detail::Ref(scope.FindVar(Input("RankTable")), "Cannot find input rank table variable %s", Input("RankTable")).Get<framework::LoDRankTable>(); auto &rank_table = GET_DATA_SAFELY(scope.FindVar(Input("RankTable")), "Input", "RankTable", "ReorderLoDTensorByRankTable").Get<framework::LoDRankTable>();
auto &out = *detail::Ref(scope.FindVar(Output("Out")), "Cannot find output lod tensor variable %s", Output("Out")).GetMutable<framework::LoDTensor>(); auto &out = *(GET_DATA_SAFELY(scope.FindVar(Output("Out")), "Output", "Out", "ReorderLoDTensorByRankTable").GetMutable<framework::LoDTensor>());
out.Resize(x.dims()); out.Resize(x.dims());
out.mutable_data(x.place(), x.type()); out.mutable_data(x.place(), x.type());
......
...@@ -17,8 +17,6 @@ limitations under the License. */ ...@@ -17,8 +17,6 @@ limitations under the License. */
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -18,7 +18,6 @@ ...@@ -18,7 +18,6 @@
#include <vector> #include <vector>
#include "boost/optional.hpp" #include "boost/optional.hpp"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/concat_and_split.h"
namespace paddle { namespace paddle {
...@@ -47,16 +46,28 @@ inline framework::LoD ConcatLoD(const Container &xs, ...@@ -47,16 +46,28 @@ inline framework::LoD ConcatLoD(const Container &xs,
lod.emplace_back(result); lod.emplace_back(result);
return lod; return lod;
} }
template <typename T, typename... ARGS>
inline std::vector<std::reference_wrapper<T>> GetDataVectorSafely(
const std::vector<T *> &vec, ARGS &&... args) {
std::vector<std::reference_wrapper<T>> result;
result.reserve(vec.size());
for (auto *ptr : vec) {
PADDLE_ENFORCE_NOT_NULL(ptr, platform::errors::InvalidArgument(
"The input variable X contains nullptr."));
result.emplace_back(*ptr);
}
return result;
}
} // namespace detail } // namespace detail
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class SeqConcatKernel : public framework::OpKernel<T> { class SeqConcatKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto xs = detail::VectorRef(context.MultiInput<framework::LoDTensor>("X"), "Cannot find multiple input X"); auto xs = detail::GetDataVectorSafely(context.MultiInput<framework::LoDTensor>("X"));
auto &out = detail::Ref(context.Output<framework::LoDTensor>("Out"), "Cannot find output"); auto &out = *context.Output<framework::LoDTensor>("Out");
size_t lod_size = 0; size_t lod_size = 0;
for (auto &x : xs) { for (auto &x : xs) {
...@@ -141,9 +152,9 @@ class SeqConcatGradKernel : public framework::OpKernel<T> { ...@@ -141,9 +152,9 @@ class SeqConcatGradKernel : public framework::OpKernel<T> {
math::SplitFunctor<DeviceContext, T> functor; math::SplitFunctor<DeviceContext, T> functor;
functor(context.template device_context<DeviceContext>(), functor(context.template device_context<DeviceContext>(),
detail::Ref( GET_DATA_SAFELY(
context.Input<framework::Tensor>(framework::GradVarName("Out")), context.Input<framework::Tensor>(framework::GradVarName("Out")),
"Sequence Concat OG must be set"), "Input", framework::GradVarName("Out"), "SeqConcatGrad"),
sliced_x_ptr, 0, &sliced_dx_ptr); sliced_x_ptr, 0, &sliced_dx_ptr);
} }
}; };
......
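For the multi-input case, detail::VectorRef is replaced above by the new detail::GetDataVectorSafely helper, which checks every pointer and returns a vector of std::reference_wrapper so downstream code can work with guaranteed-valid references. A standalone sketch of that idea, with std::invalid_argument standing in for PADDLE_ENFORCE_NOT_NULL so it compiles outside Paddle (ToRefVector is an illustrative name, not library code):

  #include <functional>
  #include <stdexcept>
  #include <vector>

  template <typename T>
  std::vector<std::reference_wrapper<T>> ToRefVector(const std::vector<T*>& vec) {
    std::vector<std::reference_wrapper<T>> result;
    result.reserve(vec.size());
    for (auto* ptr : vec) {
      // Reject null entries up front so every element below is a valid reference.
      if (ptr == nullptr) {
        throw std::invalid_argument("The input vector contains nullptr.");
      }
      result.emplace_back(*ptr);
    }
    return result;
  }

  // Usage: iterate the result as plain references.
  //   std::vector<int*> raw = {...};
  //   for (int& x : ToRefVector(raw)) { x += 1; }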
...@@ -19,7 +19,6 @@ limitations under the License. */ ...@@ -19,7 +19,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -18,7 +18,6 @@ limitations under the License. */ ...@@ -18,7 +18,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/var_type_inference.h" #include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
......
...@@ -350,6 +350,51 @@ struct EnforceNotMet : public std::exception { ...@@ -350,6 +350,51 @@ struct EnforceNotMet : public std::exception {
/** EXTENDED TOOL FUNCTIONS WITH CHECKING **/ /** EXTENDED TOOL FUNCTIONS WITH CHECKING **/
/*
* Summary: This macro is used to get the Variable or the internal type
* data (such as LoDTensor or SelectedRows) of an op's Input or
* Output, generally when calling scope.FindVar(Input/
* Output("Name")) or ctx.Input<LoDTensor>().
* The macro first checks whether the obtained pointer is null,
* and returns the dereferenced data if it is not.
*
* Note: This macro is only suitable for specific scenarios and
* is not intended to be widely used. If it cannot meet your
* requirements, please use the other PADDLE_ENFORCE* check macros.
*
* Parameters:
* __PTR: pointer
* __ROLE: (string), Input or Output
* __NAME: (string), Input or Output name
* __OP_TYPE: (string), the op type
*
* Return: A reference to the data pointed to by the pointer.
*
* Examples:
* GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X", "Mul");
*/
#define GET_DATA_SAFELY(__PTR, __ROLE, __NAME, __OP_TYPE) \
(([&]() -> std::add_lvalue_reference<decltype(*(__PTR))>::type { \
auto* ptr = (__PTR); \
if (UNLIKELY(nullptr == ptr)) { \
__THROW_ERROR_INTERNAL__( \
"%s\n [Hint: pointer " #__PTR " should not be null.]", \
paddle::platform::errors::NotFound( \
"Unable to get %s data of %s %s in operator %s. " \
"Possible reasons are:\n" \
" 1. The %s is not the %s of operator %s;\n" \
" 2. The %s has no corresponding variable passed in;\n" \
" 3. The %s corresponding variable is not initialized.", \
paddle::platform::demangle( \
typeid(std::add_lvalue_reference<decltype(*ptr)>::type) \
.name()), \
__ROLE, __NAME, __OP_TYPE, __NAME, __ROLE, __OP_TYPE, __NAME, \
__NAME) \
.ToString()); \
} \
return *ptr; \
})())
/* /*
* Summary: This macro is used to check whether op has specified * Summary: This macro is used to check whether op has specified
* Input or Output Variables. Because op's Input and Output * Input or Output Variables. Because op's Input and Output
......
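Mechanically, the GET_DATA_SAFELY macro added above wraps the null check in an immediately invoked lambda whose return type is an lvalue reference to *__PTR, which is why call sites can bind the whole expression with auto&. A stripped-down, self-contained sketch of that mechanism, using std::runtime_error in place of Paddle's enforce machinery (GET_REF_OR_THROW and the sample variables are illustrative only):

  #include <stdexcept>
  #include <string>
  #include <type_traits>

  // Returns *ptr as an lvalue reference, or throws if ptr is null.
  #define GET_REF_OR_THROW(__PTR, __NAME)                            \
    (([&]() -> std::add_lvalue_reference<decltype(*(__PTR))>::type { \
      auto* ptr = (__PTR);                                           \
      if (ptr == nullptr) {                                          \
        throw std::runtime_error(std::string("null pointer: ") +     \
                                 std::string(__NAME));               \
      }                                                              \
      return *ptr;                                                   \
    })())

  int main() {
    int v = 10;
    int* p = &v;
    auto& r = GET_REF_OR_THROW(p, "X");  // r refers to v
    r = 11;                              // writes through the reference
    int* q = nullptr;
    try {
      GET_REF_OR_THROW(q, "Y");          // throws std::runtime_error
    } catch (const std::runtime_error&) {
    }
    return v == 11 ? 0 : 1;
  }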
...@@ -362,6 +362,22 @@ TEST(enforce, cannot_to_string_type) { ...@@ -362,6 +362,22 @@ TEST(enforce, cannot_to_string_type) {
PADDLE_ENFORCE_NE(list.begin(), list.end()); PADDLE_ENFORCE_NE(list.begin(), list.end());
} }
TEST(GET_DATA_SAFELY_MACRO, SUCCESS) {
int* a = new int(10);
GET_DATA_SAFELY(a, "Input", "X", "dummy");
}
TEST(GET_DATA_SAFELY_MACRO, FAIL) {
bool caught_exception = false;
try {
int* a = nullptr;
GET_DATA_SAFELY(a, "Input", "X", "dummy");
} catch (paddle::platform::EnforceNotMet& error) {
caught_exception = true;
}
EXPECT_TRUE(caught_exception);
}
TEST(OP_INOUT_CHECK_MACRO, SUCCESS) { TEST(OP_INOUT_CHECK_MACRO, SUCCESS) {
OP_INOUT_CHECK(true, "Input", "X", "dummy"); OP_INOUT_CHECK(true, "Input", "X", "dummy");
} }
......