From 856c26faef0ba3cd2a3bd1b0446a9f57ef610a04 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 3 Sep 2018 15:38:27 +0800 Subject: [PATCH] fix elementwise (#13146) --- .../fluid/operators/elementwise_op_function.h | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/elementwise_op_function.h b/paddle/fluid/operators/elementwise_op_function.h index d5b9b2dac0..b1a399c22c 100644 --- a/paddle/fluid/operators/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise_op_function.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" @@ -94,8 +95,11 @@ class RowwiseTransformIterator; template class MidWiseTransformIterator; +// NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17 template -class RowwiseTransformIterator { +class RowwiseTransformIterator + : public std::iterator { public: RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {} @@ -126,7 +130,9 @@ class RowwiseTransformIterator { }; template -class MidWiseTransformIterator { +class MidWiseTransformIterator + : public std::iterator { public: MidWiseTransformIterator(const T *ptr, int n, int post) : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} @@ -479,8 +485,13 @@ void ElemwiseGradComputeNoBroadcast( const framework::Tensor &dout, int axis, framework::Tensor *dx, framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { size_t N = static_cast(framework::product(x_dim)); +#if !defined(_WIN32) platform::ForRange for_range( ctx.template device_context(), N); +#else + platform::ForRange for_range( + ctx.device_context(), N); +#endif // !_WIN32 for_range(ElemwiseGradNoBroadcast{ x.data(), y.data(), out.data(), dout.data(), dx_op, dy_op, dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace()), @@ -633,13 +644,13 @@ void ElementwiseGradCompute(const framework::ExecutionContext &ctx, template + void ElementwiseComputeEx(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, int axis, Functor func, framework::Tensor *z) { TransformFunctor functor( x, y, z, ctx.template device_context(), func); - auto x_dims = x->dims(); auto y_dims_untrimed = y->dims(); PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(), -- GitLab