Commit 76e188e5 authored by chengduoZH

Add layer norm [GPU]

Parent 71a70f20
@@ -62,7 +62,7 @@ class CompareOpKernel
     z->mutable_data<T>(context.GetPlace());
     int axis = context.Attr<int>("axis");
     ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis,
-                                                          z);
+                                                          Functor(), z);
   }
 };
...
@@ -35,7 +35,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
     auto* z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
     int axis = ctx.Attr<int>("axis");
-    ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
+    ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                          AddFunctor<T>(), z);
   }
 };
...
@@ -35,7 +35,8 @@ class ElementwiseDivKernel : public framework::OpKernel<T> {
     auto* z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
     int axis = ctx.Attr<int>("axis");
-    ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
+    ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                          DivFunctor<T>(), z);
   }
 };
...
@@ -35,7 +35,8 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> {
     auto* z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
     int axis = ctx.Attr<int>("axis");
-    ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
+    ElementwiseComputeEx<MaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                          MaxFunctor<T>(), z);
   }
 };
...
@@ -35,7 +35,8 @@ class ElementwiseMinKernel : public framework::OpKernel<T> {
     auto* z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
     int axis = ctx.Attr<int>("axis");
-    ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
+    ElementwiseComputeEx<MinFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                          MinFunctor<T>(), z);
   }
 };
...
@@ -34,7 +34,8 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
     auto* z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
     int axis = ctx.Attr<int>("axis");
-    ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
+    ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                          MulFunctor<T>(), z);
   }
 };
...
@@ -365,10 +365,10 @@ template <typename Functor, typename DeviceContext, typename T,
           typename OutType = T>
 void ElementwiseComputeEx(const framework::ExecutionContext& ctx,
                           const framework::Tensor* x,
-                          const framework::Tensor* y, int axis,
+                          const framework::Tensor* y, int axis, Functor func,
                           framework::Tensor* z) {
   TransformFunctor<Functor, T, DeviceContext, OutType> functor(
-      x, y, z, ctx.template device_context<DeviceContext>(), Functor());
+      x, y, z, ctx.template device_context<DeviceContext>(), func);
   auto x_dims = x->dims();
   auto y_dims = y->dims();
...
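The change above threads a caller-constructed functor through ElementwiseComputeEx instead of default-constructing Functor() inside the helper, which is what lets stateful functors such as DivAndSqrtFunctor below carry their epsilon into the computation. A minimal, framework-free sketch of the pattern (hypothetical names, not code from this commit):

#include <cmath>
#include <cstdio>
#include <vector>

// Stand-in for the epsilon-carrying functor: a default-constructed Functor()
// could not hold this state, so the caller now passes an instance.
template <typename T>
struct DivAndSqrtFunctor {
  explicit DivAndSqrtFunctor(T epsilon) : epsilon_(epsilon) {}
  T operator()(T a, T b) const { return a / (std::sqrt(b) + epsilon_); }

 private:
  T epsilon_;
};

// Simplified analogue of ElementwiseComputeEx: z[i] = func(x[i], y[i]).
template <typename Functor, typename T>
void ElementwiseApply(const std::vector<T>& x, const std::vector<T>& y,
                      Functor func, std::vector<T>* z) {
  z->resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) (*z)[i] = func(x[i], y[i]);
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f};
  std::vector<float> var = {4.f, 4.f, 4.f};
  std::vector<float> z;
  // The caller builds the stateful instance, mirroring
  // DivAndSqrtFunctor<T>(static_cast<T>(epsilon)) in the layer_norm kernels.
  ElementwiseApply(x, var, DivAndSqrtFunctor<float>(1e-5f), &z);
  for (float v : z) std::printf("%f\n", v);  // each value is x[i] / (2 + 1e-5)
  return 0;
}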
@@ -36,7 +36,8 @@ class ElementwisePowKernel : public framework::OpKernel<T> {
     auto* z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
     int axis = ctx.Attr<int>("axis");
-    ElementwiseComputeEx<PowFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
+    ElementwiseComputeEx<PowFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                          PowFunctor<T>(), z);
   }
 };
...
@@ -34,7 +34,8 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
     auto* z = ctx.Output<Tensor>("Out");
     z->mutable_data<T>(ctx.GetPlace());
     int axis = ctx.Attr<int>("axis");
-    ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis, z);
+    ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                          SubFunctor<T>(), z);
   }
 };
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/operators/layer_norm_op.h"
+#include "paddle/operators/elementwise_op_function.h"
+#include "paddle/operators/math/math_function.h"
 namespace paddle {
 namespace operators {
...
@@ -353,8 +355,9 @@ namespace ops = paddle::operators;
 REGISTER_OP(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
             layer_norm_grad, ops::LayerNormGradOp);
 REGISTER_OP_CPU_KERNEL(
-    layer_norm,
-    ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>);
+    layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     layer_norm_grad,
-    ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/operators/layer_norm_op.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
namespace {
template <typename T>
struct SubAndSquareFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return (a - b) * (a - b); }
};
template <typename T>
struct DivAndSqrtFunctor {
explicit DivAndSqrtFunctor(T epsilon) { epsilon_ = epsilon; }
inline HOSTDEVICE T operator()(T a, T b) const {
return a / (sqrt(b) + epsilon_);
}
private:
T epsilon_;
};
template <typename T>
struct MulFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};
template <typename T>
struct AddFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct SubFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};
template <typename T>
struct MulInvVarFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
return a * std::sqrt(1.0 / b);
}
};
} // namespace
template <typename DeviceContext, typename T>
class LayerNormCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto *scale = ctx.Input<Tensor>("Scale");
auto *bias = ctx.Input<Tensor>("Bias");
auto x = *ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Y");
auto *mean = ctx.Output<Tensor>("Mean");
auto *var = ctx.Output<Tensor>("Variance");
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto &x_dims = x.dims();
y->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
framework::DDim matrix_shape({left, right});
x.Resize(matrix_shape);
y->Resize(matrix_shape);
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::RowwiseMean<DeviceContext, T> row_mean;
// functor-> get mean
row_mean(dev_ctx, x, mean);
// functor-> get variance
ElementwiseComputeEx<SubAndSquareFunctor<T>, DeviceContext, T>(
ctx, &x, mean, /*axis*/ 0, SubAndSquareFunctor<T>(), y);
row_mean(dev_ctx, *y, var);
// functor-> get norm_out
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &x, mean, /*axis*/ 0, SubFunctor<T>(), y);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, y, var, /*axis*/ 0, DivAndSqrtFunctor<T>(static_cast<T>(epsilon)),
y);
framework::DDim scale_shape({right});
if (scale) {
Tensor scale_matrix = *scale;
scale_matrix.Resize(scale_shape);
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, y, &scale_matrix, /*axis*/ 1, MulFunctor<T>(), y);
}
if (bias) {
Tensor bias_matrix = *bias;
bias_matrix.Resize(scale_shape);
ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(
ctx, y, &bias_matrix, /*axis*/ 1, AddFunctor<T>(), y);
}
y->Resize(x_dims);
}
};
template <typename DeviceContext, typename T>
class LayerNormCUDAGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto x = *ctx.Input<Tensor>("X");
auto mean = *ctx.Input<Tensor>("Mean");
auto var = *ctx.Input<Tensor>("Variance");
auto scale = *ctx.Input<Tensor>("Scale");
auto d_y = *ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
const auto &x_dims = x.dims();
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
framework::DDim matrix_shape({left, right});
d_y.Resize(matrix_shape);
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::ColwiseSum<DeviceContext, T> colwise_sum;
Tensor temp;
Tensor temp_norm;
if (d_scale || d_x) {
x.Resize(matrix_shape);
temp.mutable_data<T>(matrix_shape, ctx.GetPlace());
temp_norm.mutable_data<T>(matrix_shape, ctx.GetPlace());
// get x_norm
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &x, &mean, /*axis*/ 0, SubFunctor<T>(), &temp_norm);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &var, /*axis*/ 0,
DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), &temp_norm);
}
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
colwise_sum(dev_ctx, d_y, d_bias);
}
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &d_y, /*axis*/ 0, MulFunctor<T>(), &temp);
colwise_sum(dev_ctx, temp, d_scale);
}
if (d_x) {
framework::DDim vec_shape({left});
d_x->mutable_data<T>(ctx.GetPlace());
Tensor temp_vec;
temp_vec.mutable_data<T>(vec_shape, ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<DeviceContext>();
math::RowwiseMean<DeviceContext, T> row_mean;
if (d_scale) {
// dy_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &d_y, &scale, /*axis*/ 1, MulFunctor<T>(), &temp);
framework::Copy(temp, ctx.GetPlace(), ctx.device_context(), d_x);
// dy_dmean_dx
row_mean(dev_ctx, temp, &temp_vec);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
// dy_var_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
} else {
// dy_dx
framework::Copy(d_y, ctx.GetPlace(), ctx.device_context(), d_x);
// dy_dmean_dx
row_mean(dev_ctx, d_y, &temp_vec);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_vec, /*axis*/ 0, SubFunctor<T>(), d_x);
// dy_var_dx
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &d_y, &temp_norm, /*axis*/ 0, MulFunctor<T>(), &temp);
}
// dy_var_dx
row_mean(dev_ctx, temp, &temp_vec);
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(
ctx, &temp_norm, &temp_vec, /*axis*/ 0, MulFunctor<T>(), &temp_norm);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, d_x, &temp_norm, /*axis*/ 0, SubFunctor<T>(), d_x);
ElementwiseComputeEx<DivAndSqrtFunctor<T>, DeviceContext, T>(
ctx, d_x, &var, /*axis*/ 0,
DivAndSqrtFunctor<T>(static_cast<T>(epsilon)), d_x);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
layer_norm,
ops::LayerNormCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormCUDAKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
layer_norm_grad,
ops::LayerNormCUDAGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormCUDAGradKernel<paddle::platform::CUDADeviceContext, double>);
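Read as a whole, the forward kernel above flattens the input to a left × right matrix at begin_norm_axis and then assembles, per row i and column j (a restatement derived from the code, not text in the commit; γ = Scale, β = Bias, R = right):

\mu_i = \frac{1}{R}\sum_{j=1}^{R} x_{ij}, \qquad
\sigma_i^2 = \frac{1}{R}\sum_{j=1}^{R} (x_{ij} - \mu_i)^2, \qquad
y_{ij} = \gamma_j \, \frac{x_{ij} - \mu_i}{\sqrt{\sigma_i^2} + \epsilon} + \beta_j

Note that ε is added to the square root of the variance (DivAndSqrtFunctor), not inside it. Under the same reading, with \hat{x}_{ij} = (x_{ij} - \mu_i)/(\sqrt{\sigma_i^2} + \epsilon) and g_{ij} = \gamma_j \, \partial L/\partial y_{ij} (or plain \partial L/\partial y_{ij} when Scale is absent), the gradient kernel computes

\frac{\partial L}{\partial \beta_j} = \sum_i \frac{\partial L}{\partial y_{ij}}, \qquad
\frac{\partial L}{\partial \gamma_j} = \sum_i \hat{x}_{ij} \, \frac{\partial L}{\partial y_{ij}}, \qquad
\frac{\partial L}{\partial x_{ij}} = \frac{1}{\sqrt{\sigma_i^2} + \epsilon}
    \Big( g_{ij} - \frac{1}{R}\sum_{k} g_{ik} - \hat{x}_{ij} \, \frac{1}{R}\sum_{k} g_{ik} \hat{x}_{ik} \Big)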
@@ -331,6 +331,12 @@ template struct RowwiseAdd<platform::CPUDeviceContext, double>;
 template struct ColwiseSum<platform::CPUDeviceContext, float>;
 template struct ColwiseSum<platform::CPUDeviceContext, double>;

+template struct RowwiseSum<platform::CPUDeviceContext, float>;
+template struct RowwiseSum<platform::CPUDeviceContext, double>;
+
+template struct RowwiseMean<platform::CPUDeviceContext, float>;
+template struct RowwiseMean<platform::CPUDeviceContext, double>;
+
 } // namespace math
 } // namespace operators
 } // namespace paddle
@@ -325,6 +325,31 @@ void ColwiseSum<platform::CUDADeviceContext, double>::operator()(
       vector->data<double>());
 }

+template struct RowwiseSum<platform::CUDADeviceContext, float>;
+// template struct RowwiseSum<platform::CUDADeviceContext, double>;
+// TODO(zcd): Following ColwiseSum format, need to confirm.
+// The RowwiseSum<platform::CUDADeviceContext, double> failed in debug mode,
+// and only failed for this case. So reimplemented it.
+template <>
+void RowwiseSum<platform::CUDADeviceContext, double>::operator()(
+    const platform::CUDADeviceContext& context, const framework::Tensor& input,
+    framework::Tensor* vector) {
+  auto in_dims = input.dims();
+  auto size = input.numel() / in_dims[0];
+  PADDLE_ENFORCE_EQ(vector->numel(), in_dims[0]);
+  framework::Tensor one;
+  one.mutable_data<double>({size}, context.GetPlace());
+  SetConstant<platform::CUDADeviceContext, double> set;
+  set(context, &one, static_cast<double>(1.0));
+  gemv<platform::CUDADeviceContext, double>(
+      context, true, static_cast<int>(in_dims[1]), static_cast<int>(in_dims[0]),
+      1.0, one.data<double>(), input.data<double>(), 0.0,
+      vector->data<double>());
+}
+
+template struct RowwiseMean<platform::CUDADeviceContext, float>;
+template struct RowwiseMean<platform::CUDADeviceContext, double>;
+
 } // namespace math
 } // namespace operators
 } // namespace paddle
@@ -128,6 +128,18 @@ struct ColwiseSum {
                   framework::Tensor* vec);
 };

+template <typename DeviceContext, typename T>
+struct RowwiseSum {
+  void operator()(const DeviceContext& context, const framework::Tensor& input,
+                  framework::Tensor* vec);
+};
+
+template <typename DeviceContext, typename T>
+struct RowwiseMean {
+  void operator()(const DeviceContext& context, const framework::Tensor& input,
+                  framework::Tensor* vec);
+};
+
 } // namespace math
 } // namespace operators
 } // namespace paddle
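Both new reductions act on a 2-D tensor of shape height × width and reduce along the second dimension into a vector of height elements:

\text{RowwiseSum}(X)_i = \sum_{j=1}^{\text{width}} X_{ij}, \qquad
\text{RowwiseMean}(X)_i = \frac{1}{\text{width}} \sum_{j=1}^{\text{width}} X_{ij}

which is what the generic Eigen implementations in the next hunk express with in.sum and in.mean over dimension 1.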
@@ -87,6 +87,88 @@ class ColwiseSum<platform::CPUDeviceContext, T> {
   }
 };

+template <typename DeviceContext, typename T>
+void RowwiseMean<DeviceContext, T>::operator()(const DeviceContext& context,
+                                               const framework::Tensor& input,
+                                               framework::Tensor* out) {
+  auto in_dims = input.dims();
+  PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
+  PADDLE_ENFORCE_EQ(out->numel(), in_dims[0]);
+
+  auto in = framework::EigenMatrix<T>::From(input);
+  auto vec = framework::EigenVector<T>::Flatten(*out);
+
+  vec.device(*context.eigen_device()) = in.mean(Eigen::array<int, 1>({{1}}));
+}
+
+// TODO(zcd): Following ColwiseSum format, need to confirm.
+// Specialize for CPU, since Eigen implement a general reduce. However,
+// rowwise-sum can be easily implemented. General reduce has a huge overhead in
+// CPU
+template <typename T>
+class RowwiseMean<platform::CPUDeviceContext, T> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor* out) {
+    auto& in_dims = input.dims();
+    PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
+    auto height = in_dims[0];
+    auto size = in_dims[1];
+    PADDLE_ENFORCE_EQ(out->numel(), height);
+    auto inv_size = 1.0 / size;
+    T* out_buf = out->mutable_data<T>(out->place());
+    const T* in_buf = input.data<T>();
+
+    for (size_t i = 0; i < static_cast<size_t>(height); ++i) {
+      T sum = 0;
+      for (size_t j = 0; j < static_cast<size_t>(size); ++j) {
+        sum += in_buf[i * size + j];
+      }
+      out_buf[i] = sum * inv_size;
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+void RowwiseSum<DeviceContext, T>::operator()(const DeviceContext& context,
+                                              const framework::Tensor& input,
+                                              framework::Tensor* out) {
+  auto in_dims = input.dims();
+  PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
+  PADDLE_ENFORCE_EQ(out->numel(), in_dims[0]);
+
+  auto in = framework::EigenMatrix<T>::From(input);
+  auto vec = framework::EigenVector<T>::Flatten(*out);
+
+  vec.device(*context.eigen_device()) = in.sum(Eigen::array<int, 1>({{1}}));
+}
+
+// TODO(zcd): Following ColwiseSum format, need to confirm.
+// Specialize for CPU, since Eigen implement a general reduce. However,
+// rowwise-sum can be easily implemented. General reduce has a huge overhead in
+// CPU
+template <typename T>
+class RowwiseSum<platform::CPUDeviceContext, T> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::Tensor& input, framework::Tensor* out) {
+    auto& in_dims = input.dims();
+    PADDLE_ENFORCE_EQ(in_dims.size(), 2U);
+    auto height = in_dims[0];
+    auto size = in_dims[1];
+    PADDLE_ENFORCE_EQ(out->numel(), size);
+
+    T* out_buf = out->mutable_data<T>(out->place());
+    const T* in_buf = input.data<T>();
+
+    for (size_t i = 0; i < static_cast<size_t>(height); ++i) {
+      T sum = 0;
+      for (size_t j = 0; j < static_cast<size_t>(size); ++j) {
+        sum += in_buf[i * size + j];
+      }
+      out_buf[i] = sum;
+    }
+  }
+};
+
 } // namespace math
 } // namespace operators
 } // namespace paddle
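For a quick numerical sanity check of the kernels above outside the framework, a standalone reference of the same forward computation (plain C++ with made-up test values; ε applied outside the square root, exactly as DivAndSqrtFunctor does):

#include <cmath>
#include <cstdio>
#include <vector>

// Reference layer norm over a left x right matrix stored row-major,
// mirroring the flatten_to_2d layout used by the kernels above.
void LayerNormRef(const std::vector<float>& x, int left, int right,
                  const std::vector<float>& scale,
                  const std::vector<float>& bias, float epsilon,
                  std::vector<float>* y) {
  y->resize(x.size());
  for (int i = 0; i < left; ++i) {
    float mean = 0.f, var = 0.f;
    for (int j = 0; j < right; ++j) mean += x[i * right + j];
    mean /= right;
    for (int j = 0; j < right; ++j) {
      float d = x[i * right + j] - mean;
      var += d * d;
    }
    var /= right;
    for (int j = 0; j < right; ++j) {
      float norm = (x[i * right + j] - mean) / (std::sqrt(var) + epsilon);
      (*y)[i * right + j] = norm * scale[j] + bias[j];
    }
  }
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f};  // a 2 x 2 example
  std::vector<float> scale = {1.f, 1.f}, bias = {0.f, 0.f}, y;
  LayerNormRef(x, 2, 2, scale, bias, 1e-5f, &y);
  for (float v : y) std::printf("%f ", v);  // approximately -1 1 -1 1
  std::printf("\n");
  return 0;
}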