Unverified commit 10d9ab4b, authored by Noel, committed by GitHub

[pnorm] Optimize p_norm op for special cases (#37685)

Parent 3a339cc0
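For context, the special cases the p_norm CUDA kernel now dispatches on (porder == 0, +INFINITY, -INFINITY, and the general finite p) compute the following quantities. This is a minimal CPU sketch, not part of the patch; the helper name and the flat-vector layout are assumptions for illustration, while the real op reduces along `axis` of an N-D tensor:

```cpp
// Hypothetical CPU reference of the p-norm variants handled by the kernel below.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>

double reference_p_norm(const float* x, std::size_t n, float porder) {
  if (porder == 0.0f) {
    // porder == 0: count of nonzero elements (what NonzeroAndSum computes).
    double count = 0.0;
    for (std::size_t i = 0; i < n; ++i) count += (x[i] != 0.0f) ? 1.0 : 0.0;
    return count;
  }
  if (porder == INFINITY) {
    // porder == +inf: maximum absolute value (AbsAndMax).
    double m = 0.0;
    for (std::size_t i = 0; i < n; ++i)
      m = std::max(m, std::fabs(static_cast<double>(x[i])));
    return m;
  }
  if (porder == -INFINITY) {
    // porder == -inf: minimum absolute value (AbsAndMin).
    double m = std::numeric_limits<double>::infinity();
    for (std::size_t i = 0; i < n; ++i)
      m = std::min(m, std::fabs(static_cast<double>(x[i])));
    return m;
  }
  // General finite p: (sum_i |x_i|^p)^(1/p), i.e. |x|^p -> sum -> ^(1/p),
  // the elementwise / reduce / elementwise pipeline used in the else branch.
  double sum = 0.0;
  for (std::size_t i = 0; i < n; ++i)
    sum += std::pow(std::fabs(static_cast<double>(x[i])), porder);
  return std::pow(sum, 1.0 / porder);
}
```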
@@ -21,7 +21,10 @@ limitations under the License. */
 namespace cub = hipcub;
 #endif
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/operators/p_norm_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
 #include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
@@ -56,87 +59,94 @@ __device__ __forceinline__ double inline_pow(double base, double exponent) {
   return pow(base, exponent);
 }
 
-template <typename T, int BlockDim>
-__global__ void Pnorm(const T* x, const int pre,
-                      const int axis_n,  // dim in axis
-                      const int post, float porder, T* out_norm) {
-  using MT = typename details::MPTypeTrait<T>::Type;
-  typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  int num = pre * post;
-  auto porder_t = static_cast<MT>(porder);
-  auto porder_inv = static_cast<MT>(1.0 / porder);
-  for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    int base = (i / post) * post * axis_n + (i % post);
-    MT sum = static_cast<MT>(0.0);
-    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      const MT x_ij = static_cast<MT>(x[base + j * post]);
-      sum += inline_pow(inline_abs(x_ij), porder_t);
-    }
-    MT reduce_result = BlockReduce(temp_storage).Sum(sum);
-    if (threadIdx.x == 0)
-      out_norm[i] = static_cast<T>(inline_pow(reduce_result, porder_inv));
-  }
-}
-
-template <typename T, int BlockDim>
-__global__ void ZeorNorm(const T* x, const int pre,
-                         const int axis_n,  // dim in axis
-                         const int post, T* out_norm) {
-  using MT = typename details::MPTypeTrait<T>::Type;
-  typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  int num = pre * post;
-  for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    int base = (i / post) * post * axis_n + (i % post);
-    MT sum = static_cast<MT>(0.0);
-    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      const MT x_ij = static_cast<MT>(x[base + j * post]);
-      sum += static_cast<MT>(static_cast<double>(x_ij) != 0);
-    }
-    MT reduce_result = BlockReduce(temp_storage).Sum(sum);
-    if (threadIdx.x == 0) out_norm[i] = static_cast<T>(reduce_result);
-  }
-}
-
-template <typename T, int BlockDim>
-__global__ void InfNorm(const T* x, const int pre,
-                        const int axis_n,  // dim in axis
-                        const int post, T* out_norm) {
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  int num = pre * post;
-  for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    int base = (i / post) * post * axis_n + (i % post);
-    T cur_max = inline_abs(x[base]);
-    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      T x_ij_abs = inline_abs(x[base + j * post]);
-      if (cur_max < x_ij_abs) cur_max = x_ij_abs;
-    }
-    T reduce_result = BlockReduce(temp_storage).Reduce(cur_max, cub::Max());
-    if (threadIdx.x == 0) out_norm[i] = reduce_result;
-  }
-}
-
-template <typename T, int BlockDim>
-__global__ void NegInfNorm(const T* x, const int pre,
-                           const int axis_n,  // dim in axis
-                           const int post, T* out_norm) {
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
-  __shared__ typename BlockReduce::TempStorage temp_storage;
-  int num = pre * post;
-  for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    int base = (i / post) * post * axis_n + (i % post);
-    T cur_min = inline_abs(x[base]);
-    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      T x_ij_abs = inline_abs(x[base + j * post]);
-      if (cur_min > x_ij_abs) cur_min = x_ij_abs;
-    }
-    T reduce_result = BlockReduce(temp_storage).Reduce(cur_min, cub::Min());
-    if (threadIdx.x == 0) out_norm[i] = reduce_result;
-  }
-}
+struct IdentityFunctor {
+  HOSTDEVICE explicit inline IdentityFunctor() {}
+  HOSTDEVICE explicit inline IdentityFunctor(int n) {}
+  template <typename T>
+  HOSTDEVICE inline T operator()(const T& x) const {
+    return static_cast<T>(x);
+  }
+};
+
+struct NonzeroFunctor {
+  HOSTDEVICE explicit inline NonzeroFunctor() {}
+  HOSTDEVICE explicit inline NonzeroFunctor(int n) {}
+  template <typename T>
+  HOSTDEVICE inline T operator()(const T& x) const {
+    return static_cast<T>(static_cast<double>(x) != 0);
+  }
+};
+
+struct AbsFunctor {
+  HOSTDEVICE explicit inline AbsFunctor() {}
+  HOSTDEVICE explicit inline AbsFunctor(int n) {}
+  template <typename T>
+  HOSTDEVICE inline T operator()(const T& x) const {
+    return static_cast<T>(inline_abs(x));
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct UnsignedPowFunctor {
+  HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) {
+    this->porder = porder;
+  }
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(inline_pow(inline_abs(x), static_cast<Tx>(porder)));
+  }
+  float porder;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct PowFunctor {
+  HOSTDEVICE explicit inline PowFunctor(float porder) { this->porder = porder; }
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(inline_pow(x, static_cast<Tx>(porder)));
+  }
+  float porder;
+};
+
+template <typename Tx, typename Ty = Tx>
+struct AbsAndMin {
+  using Transformer = AbsFunctor;
+  using MT = typename details::MPTypeTrait<Ty>::Type;
+  inline Ty initial() {
+    return static_cast<Ty>(std::numeric_limits<MT>::infinity());
+  }
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
+    return (a < b) ? a : b;
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct AbsAndMax {
+  using Transformer = AbsFunctor;
+  using MT = typename details::MPTypeTrait<Ty>::Type;
+  inline Ty initial() {
+    return static_cast<Ty>(-std::numeric_limits<MT>::infinity());
+  }
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
+    return (a > b) ? a : b;
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct NonzeroAndSum {
+  using Transformer = NonzeroFunctor;
+  inline Ty initial() { return static_cast<Ty>(0.0f); }
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
+    return b + a;
+  }
+};
+
+template <typename Tx, typename Ty = Tx>
+struct IdentityAndSum {
+  using Transformer = IdentityFunctor;
+  inline Ty initial() { return static_cast<Ty>(0.0f); }
+  __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const {
+    return b + a;
+  }
+};
 
 template <typename DeviceContext, typename T>
 class PnormCUDAKernel : public framework::OpKernel<T> {
@@ -146,101 +156,83 @@ class PnormCUDAKernel : public framework::OpKernel<T> {
     auto* out_norm = ctx.Output<framework::Tensor>("Out");
     const T* x = in_x->data<T>();
     T* norm = out_norm->mutable_data<T>(ctx.GetPlace());
 
     auto xdim = in_x->dims();
     auto ndim = out_norm->dims();
     float porder = ctx.Attr<float>("porder");
     int axis = ctx.Attr<int>("axis");
-    bool asvector = ctx.Attr<bool>("asvector");
     if (axis < 0) axis = xdim.size() + axis;
-    int pre, n, post;
-    GetDims(xdim, axis, &pre, &n, &post, asvector);
-    auto& dev_ctx = ctx.cuda_device_context();
-
-#ifdef __HIPCC__
-    const int block = 256;
-#else
-    const int block = 512;
-#endif
-
-    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
-    const int max_blocks = std::max(max_threads / block, 1);
-    int grid = std::min(max_blocks, pre * post);
+    std::vector<int> reduce_axis = {axis};
+
+    auto stream = ctx.cuda_device_context().stream();
+
+    using MT = typename details::MPTypeTrait<T>::Type;
 
     if (porder == 0) {
-      ZeorNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
-                                                               norm);
+      TensorReduceFunctorImpl<T, T, NonzeroAndSum>(*in_x, out_norm, reduce_axis,
+                                                   stream);
     } else if (porder == INFINITY) {
-      InfNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
-                                                              norm);
+      TensorReduceFunctorImpl<T, T, AbsAndMax>(*in_x, out_norm, reduce_axis,
                                               stream);
     } else if (porder == -INFINITY) {
-      NegInfNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n,
-                                                                 post, norm);
+      TensorReduceFunctorImpl<T, T, AbsAndMin>(*in_x, out_norm, reduce_axis,
                                               stream);
     } else {
-      Pnorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
-                                                            porder, norm);
+      framework::Tensor tmp_x;
+      tmp_x.mutable_data<T>(xdim, ctx.GetPlace());
+      std::vector<const framework::Tensor*> ins = {in_x};
+      std::vector<framework::Tensor*> outs = {&tmp_x};
+      auto func = UnsignedPowFunctor<MT, T>(porder);
+      const auto& cuda_ctx =
+          ctx.template device_context<platform::CUDADeviceContext>();
+
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, MT, T,
+                                          UnsignedPowFunctor<MT, T>>(
+          cuda_ctx, ins, &outs, func);
+      framework::Tensor tmp_y;
+      tmp_y.mutable_data<T>(ndim, ctx.GetPlace());
+      TensorReduceFunctorImpl<T, T, IdentityAndSum>(tmp_x, &tmp_y, reduce_axis,
+                                                    stream);
+      const framework::Tensor* tmp_norm = &tmp_y;
+      ins = {tmp_norm};
+      outs = {out_norm};
+      auto func_inverse = UnsignedPowFunctor<MT, T>(1. / porder);
+
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, MT, T,
+                                          UnsignedPowFunctor<MT, T>>(
+          cuda_ctx, ins, &outs, func_inverse);
     }
   }
 };
 
-template <typename T, int BlockDim>
-__global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
-                              const float porder, const int pre,
-                              const int axis_n, const int post, const T eps,
-                              T* x_grad) {
-  using MT = typename details::MPTypeTrait<T>::Type;
-  // dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x)
-  int num = pre * post;
-  auto porder_grad = static_cast<MT>(porder - 1.0f);
-  for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    __shared__ MT pnorm_i;
-    __shared__ MT yout_i;
-
-    auto base = (i / post) * post * axis_n + (i % post);
-
-    if (threadIdx.x == 0) {
-      pnorm_i = static_cast<MT>(x_norm[i]);
-      yout_i = static_cast<MT>(y_grad[i]);
-    }
-    __syncthreads();
-
-    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      int index = base + j * post;
-      const MT x_ij = static_cast<MT>(inline_abs(x[index]));
-      x_grad[index] = static_cast<T>(
-          inline_pow(x_ij, porder_grad) /
-          (inline_pow(pnorm_i, porder_grad) + static_cast<MT>(eps)) * yout_i *
-          static_cast<MT>(inline_sign(x[index])));
-    }
-  }
-}
-
-template <typename T, int BlockDim>
-__global__ void InfNormGradient(const T* x, const T* x_norm, const T* y_grad,
-                                const int pre, const int axis_n, const int post,
-                                T* x_grad) {
-  int num = pre * post;
-  for (int i = blockIdx.x; i < num; i += gridDim.x) {
-    __shared__ T pnorm_i;
-    __shared__ T yout_i;
-    auto base = (i / post) * post * axis_n + (i % post);
-    if (threadIdx.x == 0) {
-      pnorm_i = x_norm[i];
-      yout_i = y_grad[i];
-    }
-    __syncthreads();
-
-    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
-      int index = base + j * post;
-      const T x_ij = inline_abs(x[index]);
-      if (x_ij == pnorm_i) {
-        x_grad[index] = static_cast<T>(inline_sign(x[index])) * yout_i;
-      } else {
-        x_grad[index] = static_cast<T>(0);
-      }
-    }
-  }
-}
+template <typename T>
+struct AbsMaxAndMinGradFunctor {
+  template <typename DeviceContext, typename X, typename Y, typename DX,
+            typename DY, typename Dim>
+  void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
+                  const Dim& dim, int size) {
+    auto equals = ((*x).abs() == y->broadcast(dim));
+    auto ones = dx->constant(static_cast<T>(1.));
+    auto negs = dx->constant(static_cast<T>(-1.));
+    auto zeros = dx->constant(static_cast<T>(0.));
+    auto positives = (*x) > zeros;
+    dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros) *
+                        positives.select(ones, negs);
+  }
+};
+
+template <typename T>
+struct PNormPostGradFunctor {
+  template <typename DeviceContext, typename X, typename Y, typename DX,
+            typename DY, typename Dim>
+  void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
+                  const Dim& dim, int size) {
+    auto ones = dx->constant(static_cast<T>(1.));
+    auto negs = dx->constant(static_cast<T>(-1.));
+    auto zeros = dx->constant(static_cast<T>(0.));
+    auto positives = (*x) > zeros;
+    dx->device(place) = (*dx) * dy->broadcast(dim) * y->broadcast(dim) *
+                        positives.select(ones, negs);
+  }
+};
 
 template <typename DeviceContext, typename T, typename AttrType = T>
 class PnormGradCUDAKernel : public framework::OpKernel<T> {
@@ -252,40 +244,40 @@ class PnormGradCUDAKernel : public framework::OpKernel<T> {
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     T* dx = out_dx->mutable_data<T>(ctx.GetPlace());
-    const T* x = in_x->data<T>();
-    const T* x_norm = in_norm->data<T>();
-    const T* norm_dy = in_norm_dy->data<T>();
 
     auto xdim = in_x->dims();
     float porder = ctx.Attr<float>("porder");
-    T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
     int axis = ctx.Attr<int>("axis");
-    bool asvector = ctx.Attr<bool>("asvector");
+    bool reduce_all = ((axis < 0) || (in_norm->numel() == 1));
     if (axis < 0) axis = xdim.size() + axis;
-    int pre, n, post;
-    GetDims(xdim, axis, &pre, &n, &post, asvector);
-    auto& dev_ctx = ctx.cuda_device_context();
-
-#ifdef __HIPCC__
-    const int block = 256;
-#else
-    const int block = 512;
-#endif
-
-    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
-    const int max_blocks = std::max(max_threads / block, 1);
-    int grid = std::min(max_blocks, pre * post);
+    const std::vector<int> dims = {axis};
+
+    auto& cuda_ctx = ctx.template device_context<DeviceContext>();
 
     if (porder == 0) {
       math::SetConstant<DeviceContext, T> set_zero;
-      auto& dev_ctx = ctx.template device_context<DeviceContext>();
-      set_zero(dev_ctx, out_dx, static_cast<T>(0));
+      set_zero(cuda_ctx, out_dx, static_cast<T>(0));
     } else if (porder == INFINITY || porder == -INFINITY) {
-      InfNormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
-          x, x_norm, norm_dy, pre, n, post, dx);
+      LaunchReduceGradKernel<DeviceContext, T, AbsMaxAndMinGradFunctor<T>>(
+          ctx, in_x, in_norm, in_norm_dy, out_dx, dims, reduce_all);
     } else {
-      PnormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
-          x, x_norm, norm_dy, porder, pre, n, post, eps, dx);
+      framework::Tensor tmp_norm;
+      tmp_norm.mutable_data<T>(in_norm->dims(), ctx.GetPlace());
+      std::vector<const framework::Tensor*> ins = {in_norm};
+      std::vector<framework::Tensor*> outs = {&tmp_norm};
+      auto pow_functor = PowFunctor<T>(1. - porder);
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T,
+                                          PowFunctor<T>>(cuda_ctx, ins, &outs,
+                                                         pow_functor);
+
+      ins = {in_x};
+      outs = {out_dx};
+      auto unsigned_pow = UnsignedPowFunctor<T>(porder - 1.);
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T,
+                                          UnsignedPowFunctor<T>>(
+          cuda_ctx, ins, &outs, unsigned_pow);
+      const framework::Tensor* tmp_norm_const = &tmp_norm;
+      LaunchReduceGradKernel<DeviceContext, T, PNormPostGradFunctor<T>>(
+          ctx, in_x, tmp_norm_const, in_norm_dy, out_dx, dims, reduce_all);
    }
  }
};
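For reference, the gradient formulas these two backward branches implement, as a minimal CPU sketch (not part of the patch; helper names and the flat-vector layout are assumptions, and the epsilon guard used by the old kernel is omitted):

```cpp
// Hypothetical CPU reference of the backward pass. For finite p,
//   d||x||_p / dx_i = sign(x_i) * |x_i|^(p-1) / ||x||_p^(p-1),
// which corresponds to UnsignedPowFunctor(p-1) on x, PowFunctor(1-p) on the
// norm, and the sign factor applied in PNormPostGradFunctor.
#include <cmath>
#include <cstddef>

void reference_p_norm_grad(const float* x, float norm, float norm_grad,
                           float porder, std::size_t n, float* dx) {
  for (std::size_t i = 0; i < n; ++i) {
    float sign = (x[i] > 0.0f) ? 1.0f : ((x[i] < 0.0f) ? -1.0f : 0.0f);
    dx[i] = sign * std::pow(std::fabs(x[i]), porder - 1.0f) *
            std::pow(norm, 1.0f - porder) * norm_grad;
  }
}

// For p = +/-inf the gradient flows only to elements whose |x_i| equals the
// reduced norm, carrying the upstream gradient through (AbsMaxAndMinGradFunctor).
void reference_inf_norm_grad(const float* x, float norm, float norm_grad,
                             std::size_t n, float* dx) {
  for (std::size_t i = 0; i < n; ++i) {
    float sign = (x[i] > 0.0f) ? 1.0f : -1.0f;  // choice at x_i == 0 mirrors the functor
    dx[i] = (std::fabs(x[i]) == norm) ? sign * norm_grad : 0.0f;
  }
}
```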
......
@@ -326,46 +326,14 @@ class BoolReduceKernel : public framework::OpKernel<OutT> {
   }
 };
 
-template <typename DeviceContext, typename T, typename Functor,
-          bool kNoNeedBufferX = false, bool kNoNeedBufferY = false>
-class ReduceGradKernel : public framework::OpKernel<T> {
- public:
-  void ComputeFromInput(const Tensor* input2,
-                        const framework::ExecutionContext& context) const {
-    bool reduce_all = context.Attr<bool>("reduce_all");
-    auto dims = context.Attr<std::vector<int>>("dim");
-    auto* input0 = context.Input<Tensor>("X");
-    auto* input1 = context.Input<Tensor>("Out");
-    auto* output = context.Output<Tensor>(framework::GradVarName("X"));
-    output->mutable_data<T>(context.GetPlace());
-
-    // The dims has full dim, set the reduce_all is True
-    const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
-    std::set<int> dims_set(dims.begin(), dims.end());
-    bool full_dim = true;
-    for (auto i = 0; i < input_dim_size; i++) {
-      if (dims_set.find(i) == dims_set.end()) {
-        full_dim = false;
-        break;
-      }
-    }
-    reduce_all = (reduce_all || full_dim);
-    // NOTE: EigenTensor::From() uses tensor->data()
-    // if op has NoNeedBufferVarsInferer, the corresponding kNoNeedBufferX or
-    // kNoNeedBufferY should set true
-    // and use fake var that has same dims.
-    if (kNoNeedBufferX) {
-      input0 = output;
-    }
-    if (kNoNeedBufferY) {
-      input1 = input2;
-    }
-
-    // NOTE(dengkaipeng): Out is unnecessary in some reduce kernel and
-    // not be set as Input in grad Maker, use Out_grad to replace here
-    if (!input1) input1 = input2;
-
+template <typename DeviceContext, typename T, typename Functor>
+void LaunchReduceGradKernel(const framework::ExecutionContext& context,
+                            const framework::Tensor* input0,
+                            const framework::Tensor* input1,
+                            const framework::Tensor* input2,
+                            paddle::framework::Tensor* output,
+                            const std::vector<int>& dims,
+                            bool reduce_all = false) {
   if (reduce_all) {
     auto x = EigenVector<T>::Flatten(*input0);
     auto x_reduce = EigenVector<T>::Flatten(*input1);
@@ -383,33 +351,33 @@ class ReduceGradKernel : public framework::OpKernel<T> {
     switch (rank) {
       case 1:
         ReduceGradFunctor<DeviceContext, T, 1, Functor>(
-            context.template device_context<DeviceContext>(), *input0,
-            *input1, *input2, output, dims);
+            context.template device_context<DeviceContext>(), *input0, *input1,
+            *input2, output, dims);
         break;
       case 2:
         ReduceGradFunctor<DeviceContext, T, 2, Functor>(
-            context.template device_context<DeviceContext>(), *input0,
-            *input1, *input2, output, dims);
+            context.template device_context<DeviceContext>(), *input0, *input1,
+            *input2, output, dims);
         break;
       case 3:
         ReduceGradFunctor<DeviceContext, T, 3, Functor>(
-            context.template device_context<DeviceContext>(), *input0,
-            *input1, *input2, output, dims);
+            context.template device_context<DeviceContext>(), *input0, *input1,
+            *input2, output, dims);
         break;
       case 4:
         ReduceGradFunctor<DeviceContext, T, 4, Functor>(
-            context.template device_context<DeviceContext>(), *input0,
-            *input1, *input2, output, dims);
+            context.template device_context<DeviceContext>(), *input0, *input1,
+            *input2, output, dims);
         break;
       case 5:
         ReduceGradFunctor<DeviceContext, T, 5, Functor>(
-            context.template device_context<DeviceContext>(), *input0,
-            *input1, *input2, output, dims);
+            context.template device_context<DeviceContext>(), *input0, *input1,
+            *input2, output, dims);
         break;
       case 6:
         ReduceGradFunctor<DeviceContext, T, 6, Functor>(
-            context.template device_context<DeviceContext>(), *input0,
-            *input1, *input2, output, dims);
+            context.template device_context<DeviceContext>(), *input0, *input1,
+            *input2, output, dims);
         break;
       default:
         HandleLargeDimGrad<DeviceContext, T, Functor>(context, input0, input1,
@@ -417,6 +385,51 @@ class ReduceGradKernel : public framework::OpKernel<T> {
         break;
     }
   }
+}
+
+template <typename DeviceContext, typename T, typename Functor,
+          bool kNoNeedBufferX = false, bool kNoNeedBufferY = false>
+class ReduceGradKernel : public framework::OpKernel<T> {
+ public:
+  void ComputeFromInput(const Tensor* input2,
+                        const framework::ExecutionContext& context) const {
+    bool reduce_all = context.Attr<bool>("reduce_all");
+    auto dims = context.Attr<std::vector<int>>("dim");
+    auto* input0 = context.Input<Tensor>("X");
+    auto* input1 = context.Input<Tensor>("Out");
+    auto* output = context.Output<Tensor>(framework::GradVarName("X"));
+    output->mutable_data<T>(context.GetPlace());
+
+    // The dims has full dim, set the reduce_all is True
+    const auto& input_dim_size = context.Input<Tensor>("X")->dims().size();
+    std::set<int> dims_set(dims.begin(), dims.end());
+    bool full_dim = true;
+    for (auto i = 0; i < input_dim_size; i++) {
+      if (dims_set.find(i) == dims_set.end()) {
+        full_dim = false;
+        break;
+      }
+    }
+    reduce_all = (reduce_all || full_dim);
+    // NOTE: EigenTensor::From() uses tensor->data()
+    // if op has NoNeedBufferVarsInferer, the corresponding kNoNeedBufferX or
+    // kNoNeedBufferY should set true
+    // and use fake var that has same dims.
+    if (kNoNeedBufferX) {
+      input0 = output;
+    }
+    if (kNoNeedBufferY) {
+      input1 = input2;
+    }
+
+    const std::vector<int> const_dims = dims;
+
+    // NOTE(dengkaipeng): Out is unnecessary in some reduce kernel and
+    // not be set as Input in grad Maker, use Out_grad to replace here
+    if (!input1) input1 = input2;
+
+    LaunchReduceGradKernel<DeviceContext, T, Functor>(
+        context, input0, input1, input2, output, const_dims, reduce_all);
   }
 
   void Compute(const framework::ExecutionContext& context) const override {
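This hunk extracts the body of ReduceGradKernel::ComputeFromInput into the reusable free function LaunchReduceGradKernel, which the p_norm backward kernel above now calls directly. As a usage illustration, a hypothetical functor (not in the patch) with the same call shape as AbsMaxAndMinGradFunctor and PNormPostGradFunctor, implementing the gradient of a plain sum reduction:

```cpp
// Hypothetical functor sketch: any type exposing this operator() can be
// plugged into LaunchReduceGradKernel. A sum reduction has d(sum)/dx_i = 1,
// so the upstream gradient is simply broadcast along the reduced dimension.
template <typename T>
struct MySumGradFunctor {
  template <typename DeviceContext, typename X, typename Y, typename DX,
            typename DY, typename Dim>
  void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
                  const Dim& dim, int size) {
    dx->device(place) = dy->broadcast(dim);
  }
};

// Invocation mirrors the p_norm backward kernel shown earlier:
//   LaunchReduceGradKernel<platform::CUDADeviceContext, T, MySumGradFunctor<T>>(
//       ctx, in_x, in_y, in_dy, out_dx, {axis}, /*reduce_all=*/false);
```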
......
...@@ -186,7 +186,6 @@ register_unity_group(cc ...@@ -186,7 +186,6 @@ register_unity_group(cc
norm_op.cc norm_op.cc
one_hot_op.cc one_hot_op.cc
one_hot_v2_op.cc one_hot_v2_op.cc
p_norm_op.cc
pad2d_op.cc pad2d_op.cc
pad3d_op.cc pad3d_op.cc
pad_constant_like_op.cc pad_constant_like_op.cc
...@@ -468,7 +467,6 @@ register_unity_group(cu ...@@ -468,7 +467,6 @@ register_unity_group(cu
nll_loss_op.cu nll_loss_op.cu
norm_op.cu norm_op.cu
one_hot_op.cu one_hot_op.cu
p_norm_op.cu
pad2d_op.cu pad2d_op.cu
pad3d_op.cu pad3d_op.cu
pad_constant_like_op.cu pad_constant_like_op.cu
......