提交 314e150f 编写于 作者: H Haihao Shen

Merge branch 'public-dev' into develop

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -11,14 +11,151 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,14 +11,151 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include <algorithm>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/norm_op.h" #include "paddle/fluid/operators/norm_op.h"
namespace paddle {
namespace operators {
__device__ __forceinline__ float square_root(float x) { return sqrtf(x); }
__device__ __forceinline__ double square_root(double x) { return sqrt(x); }
template <typename T, int BlockDim>
__global__ void Normalize(const T* x, const int pre,
const int axis_n, // dim in axis
const int post, const T eps, T* y, T* out_norm) {
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
int base = (i / post) * post * axis_n + (i % post);
T sum = 0.0;
__shared__ T norm;
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
const T x_ij = x[base + j * post];
sum += x_ij * x_ij;
}
T reduce_result = BlockReduce(temp_storage).Sum(sum);
if (threadIdx.x == 0) {
norm = square_root(reduce_result + eps);
out_norm[i] = norm;
}
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
const int index = base + j * post;
y[index] = x[index] / norm;
}
}
}
template <typename DeviceContext, typename T>
class NormCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_x = ctx.Input<framework::Tensor>("X");
auto* out_y = ctx.Output<framework::Tensor>("Out");
auto* out_norm = ctx.Output<framework::Tensor>("Norm");
const T* x = in_x->data<T>();
T* y = out_y->mutable_data<T>(ctx.GetPlace());
T* norm = out_norm->mutable_data<T>(ctx.GetPlace());
auto xdim = in_x->dims();
auto ndim = out_norm->dims();
int axis = ctx.Attr<int>("axis");
T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post);
auto& dev_ctx = ctx.cuda_device_context();
const int block = 512;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(max_blocks, pre * post);
Normalize<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
eps, y, norm);
}
};
template <typename T, int BlockDim>
__global__ void NormalizeGradient(const T* x, const T* x_norm, const T* y_grad,
const int pre, const int axis_n,
const int post, T* x_grad) {
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage_sum;
int num = pre * post;
for (int i = blockIdx.x; i < num; i += gridDim.x) {
T sum = 0.0;
__shared__ T row_sum;
__shared__ T row_sqrt_norm;
__shared__ T row_norm;
auto base = (i / post) * post * axis_n + (i % post);
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
sum += x[index] * y_grad[index];
}
T reduce_result = BlockReduce(temp_storage_sum).Sum(sum);
if (threadIdx.x == 0) {
row_sum = reduce_result;
row_sqrt_norm = x_norm[i];
row_norm = row_sqrt_norm * row_sqrt_norm;
}
__syncthreads();
for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
int index = base + j * post;
const T x_ij = x[index];
const T dy_ij = y_grad[index];
x_grad[index] = (dy_ij - x_ij * row_sum / row_norm) / row_sqrt_norm;
}
}
}
template <typename DeviceContext, typename T, typename AttrType = T>
class NormGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_x = ctx.Input<framework::Tensor>("X");
auto* in_norm = ctx.Input<framework::Tensor>("Norm");
auto* in_dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
T* dx = out_dx->mutable_data<T>(ctx.GetPlace());
const T* x = in_x->data<T>();
const T* x_norm = in_norm->data<T>();
const T* dy = in_dy->data<T>();
auto xdim = in_x->dims();
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post);
auto& dev_ctx = ctx.cuda_device_context();
const int block = 512;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(max_blocks, pre * post);
NormalizeGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
x, x_norm, dy, pre, n, post, dx);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext; using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(norm, ops::NormKernel<CUDA, float>, REGISTER_OP_CUDA_KERNEL(norm, ops::NormCUDAKernel<CUDA, float>,
ops::NormKernel<CUDA, double>); ops::NormCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(norm_grad, ops::NormGradKernel<CUDA, float>, REGISTER_OP_CUDA_KERNEL(norm_grad, ops::NormGradCUDAKernel<CUDA, float>,
ops::NormGradKernel<CUDA, double>); ops::NormGradCUDAKernel<CUDA, double>);
...@@ -65,14 +65,17 @@ class NormKernel : public framework::OpKernel<T> { ...@@ -65,14 +65,17 @@ class NormKernel : public framework::OpKernel<T> {
Eigen::DSizes<int, 1> rdim(1); Eigen::DSizes<int, 1> rdim(1);
// y = x / sqrt((sum(x * x) + epsilon)) // y = x / sqrt((sum(x * x) + epsilon))
// norm = sqrt(sum(x * x) + epsilon) // norm = sqrt(sum(x * x) + epsilon)
auto sum = x.pow(2).sum(rdim) + eps; auto x2 = x * x;
auto sum = x2.sum(rdim) + eps;
norm.device(*place) = sum.sqrt(); norm.device(*place) = sum.sqrt();
// y = x / norm // y = x / norm
Eigen::DSizes<int, 3> rshape(pre, 1, post); Eigen::DSizes<int, 3> rshape(pre, 1, post);
Eigen::DSizes<int, 3> bcast(1, n, 1); Eigen::DSizes<int, 3> bcast(1, n, 1);
y.device(*place) = x / norm.reshape(rshape).broadcast(bcast); y.device(*place) = x / norm.reshape(rshape).broadcast(bcast);
} }
}; };
template <typename DeviceContext, typename T, typename AttrType = T> template <typename DeviceContext, typename T, typename AttrType = T>
class NormGradKernel : public framework::OpKernel<T> { class NormGradKernel : public framework::OpKernel<T> {
public: public:
......
...@@ -63,5 +63,27 @@ class TestNormOp3(TestNormOp): ...@@ -63,5 +63,27 @@ class TestNormOp3(TestNormOp):
self.epsilon = 1e-8 self.epsilon = 1e-8
class TestNormOp4(TestNormOp):
def init_test_case(self):
self.shape = [128, 1024, 14, 14]
self.axis = 2
self.epsilon = 1e-8
def test_check_grad(self):
# since the gradient check is very slow in large shape, so skip check_grad
pass
class TestNormOp5(TestNormOp):
def init_test_case(self):
self.shape = [2048, 2048]
self.axis = 1
self.epsilon = 1e-8
def test_check_grad(self):
# since the gradient check is very slow in large shape, so skip check_grad
pass
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册