未验证 提交 856c26fa 编写于 作者: D dzhwinter 提交者: GitHub

fix elementwise (#13146)

上级 4fa3cee5
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <algorithm> #include <algorithm>
#include <iterator>
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -94,8 +95,11 @@ class RowwiseTransformIterator; ...@@ -94,8 +95,11 @@ class RowwiseTransformIterator;
template <typename T, typename DeviceContext> template <typename T, typename DeviceContext>
class MidWiseTransformIterator; class MidWiseTransformIterator;
// NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17
template <typename T> template <typename T>
class RowwiseTransformIterator<T, platform::CPUDeviceContext> { class RowwiseTransformIterator<T, platform::CPUDeviceContext>
: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
T *, T &> {
public: public:
RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {} RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
...@@ -126,7 +130,9 @@ class RowwiseTransformIterator<T, platform::CPUDeviceContext> { ...@@ -126,7 +130,9 @@ class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
}; };
template <typename T> template <typename T>
class MidWiseTransformIterator<T, platform::CPUDeviceContext> { class MidWiseTransformIterator<T, platform::CPUDeviceContext>
: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
T *, T &> {
public: public:
MidWiseTransformIterator(const T *ptr, int n, int post) MidWiseTransformIterator(const T *ptr, int n, int post)
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
...@@ -479,8 +485,13 @@ void ElemwiseGradComputeNoBroadcast( ...@@ -479,8 +485,13 @@ void ElemwiseGradComputeNoBroadcast(
const framework::Tensor &dout, int axis, framework::Tensor *dx, const framework::Tensor &dout, int axis, framework::Tensor *dx,
framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) { framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
size_t N = static_cast<size_t>(framework::product(x_dim)); size_t N = static_cast<size_t>(framework::product(x_dim));
#if !defined(_WIN32)
platform::ForRange<DeviceContext> for_range( platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), N); ctx.template device_context<DeviceContext>(), N);
#else
platform::ForRange<DeviceContext> for_range(
ctx.device_context<DeviceContext>(), N);
#endif // !_WIN32
for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{ for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op, x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()), dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
...@@ -633,13 +644,13 @@ void ElementwiseGradCompute(const framework::ExecutionContext &ctx, ...@@ -633,13 +644,13 @@ void ElementwiseGradCompute(const framework::ExecutionContext &ctx,
template <typename Functor, typename DeviceContext, typename T, template <typename Functor, typename DeviceContext, typename T,
typename OutType = T> typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext &ctx, void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
const framework::Tensor *x, const framework::Tensor *x,
const framework::Tensor *y, int axis, Functor func, const framework::Tensor *y, int axis, Functor func,
framework::Tensor *z) { framework::Tensor *z) {
TransformFunctor<Functor, T, DeviceContext, OutType> functor( TransformFunctor<Functor, T, DeviceContext, OutType> functor(
x, y, z, ctx.template device_context<DeviceContext>(), func); x, y, z, ctx.template device_context<DeviceContext>(), func);
auto x_dims = x->dims(); auto x_dims = x->dims();
auto y_dims_untrimed = y->dims(); auto y_dims_untrimed = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(), PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册