提交 0cc63549 编写于 作者: C chengduoZH

merge develop

......@@ -244,7 +244,7 @@ __device__ __forceinline__ void blockReduce(Pair* shTopK,
if (--beamSize == 0) break;
__syncthreads();
// temporary solution
// NOTE(zcd): temporary solution
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNMemDesc;
using mkldnn::memory;
template <typename T>
using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using ConstEigenArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
namespace {
template <typename T>
struct bn_type_traits {
using op_type = T;
using op_desc = typename op_type::desc;
using op_prim = typename op_type::primitive_desc;
};
template <typename T, typename Container>
void copy_to_weights(T scale_begin, T scale_end, T shift_begin, T shift_end,
Container *c) {
auto it = std::begin(*c);
std::copy(scale_begin, scale_end, std::inserter(*c, it));
std::copy(
shift_begin, shift_end,
std::inserter(*c, std::next(it, std::distance(scale_begin, scale_end))));
}
template <typename Op, typename... Args>
void run_batch_norm_op(Args &&... args) {
Op batch_norm_op{args...};
std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(batch_norm_op);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
template <typename T>
inline void *cast_const_to_void(const T *t) {
return static_cast<void *>(const_cast<T *>(t));
}
} // namespace
template <typename T>
class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto data_layout_str = ctx.Attr<std::string>("data_layout");
auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
"MKLDNN batch normalization handles only NCHW data layout");
const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const auto *x = ctx.Input<Tensor>("X");
const auto *mean = ctx.Input<Tensor>("Mean");
const auto *variance = ctx.Input<Tensor>("Variance");
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
auto *y = ctx.Output<Tensor>("Y");
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
auto *batch_mean = ctx.Output<Tensor>("SavedMean");
auto *batch_variance = ctx.Output<Tensor>("SavedVariance");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *shift = ctx.Input<Tensor>("Bias");
y->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<T>(ctx.GetPlace());
variance_out->mutable_data<T>(ctx.GetPlace());
if (!is_test) {
batch_mean->mutable_data<T>(ctx.GetPlace());
batch_variance->mutable_data<T>(ctx.GetPlace());
}
auto propagation = is_test == true ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
auto dims = paddle::framework::vectorize2int(x->dims());
auto src_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto dst_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto src_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
auto dst_pd = mkldnn::memory::primitive_desc{dst_md, mkldnn_engine};
auto src = mkldnn::memory{src_pd, cast_const_to_void(x->data<T>())};
auto dst = mkldnn::memory{dst_pd, y->data<T>()};
unsigned flags = mkldnn::use_scale_shift;
if (is_test) flags |= mkldnn::use_global_stats;
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
auto batch_norm_fwd_desc =
bn_fwd_types::op_desc{propagation, src_md, epsilon, flags};
auto batch_norm_fwd_pd =
bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine};
const unsigned int ic = dims[1];
// MKLDNN requires a single piece of memory for scale and shift/bias data
const size_t scaleshift_size = 2 * ic;
std::vector<T> scaleshift_data;
scaleshift_data.reserve(scaleshift_size);
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
shift->data<T>() + ic, &scaleshift_data);
auto scaleshift_memory = mkldnn::memory{
batch_norm_fwd_pd.weights_primitive_desc(), scaleshift_data.data()};
if (is_test) {
auto mean_memory = mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(),
cast_const_to_void(mean->data<T>())};
auto variance_memory =
mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(),
cast_const_to_void(variance->data<T>())};
run_batch_norm_op<typename bn_fwd_types::op_type>(
batch_norm_fwd_pd, src, (const mkldnn::primitive::at &)mean_memory,
(const mkldnn::primitive::at &)variance_memory, scaleshift_memory,
dst);
} else {
auto mean_memory =
mkldnn::memory{batch_norm_fwd_pd.mean_primitive_desc(),
cast_const_to_void(batch_mean->data<T>())};
auto variance_memory =
mkldnn::memory{batch_norm_fwd_pd.variance_primitive_desc(),
cast_const_to_void(batch_variance->data<T>())};
run_batch_norm_op<bn_fwd_types::op_type>(batch_norm_fwd_pd, src,
scaleshift_memory, dst,
mean_memory, variance_memory);
}
if (!is_test) {
const unsigned int in = dims[0];
const unsigned int sample_size = x->numel() / in / ic;
// saved_xx is use just in this batch of data
EigenVectorArrayMap<T> saved_mean_e(
batch_mean->mutable_data<T>(ctx.GetPlace()), ic);
EigenVectorArrayMap<T> saved_variance_e(
batch_variance->mutable_data<T>(ctx.GetPlace()), ic);
saved_mean_e.setZero();
saved_variance_e.setZero();
const unsigned int x_arr_size = in * ic;
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, x_arr_size);
for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
saved_mean_e(nc % ic) += x_arr.col(nc).sum();
}
saved_mean_e /= in * sample_size;
for (unsigned int nc = 0; nc < x_arr_size; ++nc) {
saved_variance_e(nc % ic) +=
(x_arr.col(nc) - saved_mean_e(nc % ic)).matrix().squaredNorm();
}
saved_variance_e /= in * sample_size;
ConstEigenVectorArrayMap<T> mean_arr{mean->data<T>(), ic};
ConstEigenVectorArrayMap<T> variance_arr{variance->data<T>(), ic};
EigenVectorArrayMap<T> running_mean_arr(
mean_out->mutable_data<T>(ctx.GetPlace()), ic);
EigenVectorArrayMap<T> running_var_arr(
variance_out->mutable_data<T>(ctx.GetPlace()), ic);
auto one_minus_momentum = 1. - momentum;
running_mean_arr =
mean_arr * momentum + saved_mean_e * one_minus_momentum;
running_var_arr =
variance_arr * momentum + saved_variance_e * one_minus_momentum;
}
}
};
template <typename T>
class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto data_layout_str = ctx.Attr<std::string>("data_layout");
auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE(data_layout == framework::DataLayout::kNCHW,
"MKLDNN batch normalization handles only NCHW data layout");
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
auto mkldnn_engine = dev_ctx.GetEngine();
const float epsilon = ctx.Attr<float>("epsilon");
const auto *x = ctx.Input<Tensor>("X");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *shift = ctx.Input<Tensor>("Bias");
const auto *batch_mean = ctx.Input<Tensor>("SavedMean");
const auto *batch_variance = ctx.Input<Tensor>("SavedVariance");
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *diff_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
diff_x->mutable_data<T>(ctx.GetPlace());
diff_scale->mutable_data<T>(ctx.GetPlace());
diff_shift->mutable_data<T>(ctx.GetPlace());
auto dims = paddle::framework::vectorize2int(x->dims());
unsigned flags = mkldnn::use_scale_shift | !mkldnn::use_global_stats;
auto src_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto dst_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto diff_src_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
auto diff_dst_md =
MKLDNNMemDesc(dims, memory::data_type::f32, memory::format::nchw);
using bn_bwd_types = bn_type_traits<mkldnn::batch_normalization_backward>;
using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
auto batch_norm_fwd_desc = bn_fwd_types::op_desc{
mkldnn::prop_kind::forward_training, src_md, epsilon, flags};
auto batch_norm_fwd_pd =
bn_fwd_types::op_prim{batch_norm_fwd_desc, mkldnn_engine};
auto batch_norm_bwd_desc = bn_bwd_types::op_desc{
mkldnn::prop_kind::backward, diff_dst_md, dst_md, epsilon, flags};
auto batch_norm_bwd_pd = bn_bwd_types::op_prim{
batch_norm_bwd_desc, mkldnn_engine, batch_norm_fwd_pd};
auto src = mkldnn::memory{{src_md, mkldnn_engine},
cast_const_to_void(x->data<T>())};
auto mean = mkldnn::memory{batch_norm_bwd_pd.mean_primitive_desc(),
cast_const_to_void(batch_mean->data<T>())};
auto variance =
mkldnn::memory{batch_norm_bwd_pd.variance_primitive_desc(),
cast_const_to_void(batch_variance->data<T>())};
auto diff_dst = mkldnn::memory{{diff_dst_md, mkldnn_engine},
cast_const_to_void(diff_y->data<T>())};
const unsigned int ic = dims[1];
const size_t scaleshift_size = 2 * ic;
std::vector<T> scaleshift_data;
scaleshift_data.reserve(scaleshift_size);
copy_to_weights(scale->data<T>(), scale->data<T>() + ic, shift->data<T>(),
shift->data<T>() + ic, &scaleshift_data);
auto scaleshift_memory = mkldnn::memory{
batch_norm_bwd_pd.weights_primitive_desc(), scaleshift_data.data()};
std::vector<T> diff_scaleshift_data;
diff_scaleshift_data.reserve(scaleshift_size);
copy_to_weights(diff_scale->data<T>(), diff_scale->data<T>() + ic,
diff_shift->data<T>(), diff_shift->data<T>() + ic,
&diff_scaleshift_data);
auto diff_scaleshift_memory =
mkldnn::memory{batch_norm_bwd_pd.diff_weights_primitive_desc(),
diff_scaleshift_data.data()};
auto diff_src = mkldnn::memory{{diff_src_md, mkldnn_engine},
static_cast<void *>(diff_x->data<T>())};
run_batch_norm_op<bn_bwd_types::op_type>(
batch_norm_bwd_pd, src, mean, variance, diff_dst, scaleshift_memory,
diff_src, diff_scaleshift_memory);
auto it = std::begin(diff_scaleshift_data);
std::copy(it, std::next(it, ic), diff_scale->data<T>());
std::copy(std::next(it, ic), std::end(diff_scaleshift_data),
diff_shift->data<T>());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(batch_norm, MKLDNN, paddle::platform::CPUPlace,
ops::BatchNormMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(batch_norm_grad, MKLDNN, paddle::platform::CPUPlace,
ops::BatchNormMKLDNNGradOpKernel<float>);
......@@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h"
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
......@@ -106,7 +109,18 @@ class BatchNormOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
ctx.Input<Tensor>("Variance")->type()),
"Variance input should be of float type");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
}
#endif
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library_);
}
};
......@@ -151,6 +165,9 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
"Variance of the current mini batch, "
"will apply to output when training")
.AsIntermediate();
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Batch Normalization.
......@@ -349,8 +366,19 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::OpKernelType(framework::ToDataType(t->type()),
ctx.GetPlace());
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
}
#endif
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout, library_);
}
};
......@@ -474,6 +502,7 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetInput("Scale", Input("Scale"));
op->SetInput("Bias", Input("Bias"));
op->SetInput("SavedMean", Output("SavedMean"));
op->SetInput("SavedVariance", Output("SavedVariance"));
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#ifdef __NVCC__
#include <cuda.h>
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
#endif
......@@ -336,43 +337,6 @@ static void ElemwiseGradBroadcast1CPU(const T* x, const T* y, const T* out,
}
#ifdef __NVCC__
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
// NOTE(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
const int warpSize = 32;
__shared__ T shm[warpSize];
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::__shfl_down_sync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0;
__syncthreads();
if (tid % warpSize == 0) {
shm[tid / warpSize] = val;
}
__syncthreads();
CREATE_SHFL_MASK(mask, tid < warpSize);
if (tid < warpSize) {
val = shm[tid];
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::__shfl_down_sync(mask, val, offset);
}
return val;
}
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcast1CUDAKernel(
const T* x, const T* y, const T* out, const T* dout, int h, int w,
......@@ -395,7 +359,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
if (dy) {
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = reduceSum(val, tid, h);
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
......@@ -472,7 +436,7 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
if (dy) {
int h = pre * post;
h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
val = reduceSum(val, tid, h);
val = paddle::platform::reduceSum(val, tid, h);
if (threadIdx.x == 0) {
dy[j] = val;
}
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
......@@ -30,66 +31,22 @@ __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
}
}
template <typename T>
__device__ __forceinline__ T sum_single_warp(T val) {
val += platform::__shfl_down_sync(0, val, 16);
val += platform::__shfl_down_sync(0, val, 8);
val += platform::__shfl_down_sync(0, val, 4);
val += platform::__shfl_down_sync(0, val, 2);
val += platform::__shfl_down_sync(0, val, 1);
return val;
}
// CUDA do not support dynamic arrary in template
// https://stackoverflow.com/questions/20497209
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ T* GetPointer() { return NULL; }
};
template <>
struct SharedMemory<float> {
__device__ float* GetPointer() {
extern __shared__ float s_float[];
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ double* GetPointer() {
extern __shared__ double s_double[];
return s_double;
}
};
template <typename T>
__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
const int class_num) {
int tid = threadIdx.x;
SharedMemory<T> d_sum_shared;
T* d_sum = d_sum_shared.GetPointer();
d_sum[tid] = 0;
T val = 0;
int cur_idx = tid;
int next_idx = blockIdx.x * class_num + tid;
while (cur_idx < class_num) {
d_sum[tid] +=
math::TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
next_idx += blockDim.x;
cur_idx += blockDim.x;
int idx = blockIdx.x * class_num + tid;
int end = blockIdx.x * class_num + class_num;
for (; idx < end; idx += blockDim.x) {
val += math::TolerableValue<T>()(std::log(X[idx])) * label[idx];
}
__syncthreads();
for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
if (tid < stride) d_sum[tid] += d_sum[tid + stride];
__syncthreads();
val = paddle::platform::reduceSum(val, tid, blockDim.x);
if (threadIdx.x == 0) {
Y[blockIdx.x] = -val;
}
T val = d_sum[tid];
val = sum_single_warp<T>(val);
if (tid == 0) Y[blockIdx.x] = -val;
}
} // namespace
......@@ -113,9 +70,7 @@ class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
? 512
: pow(2, static_cast<int>(std::log2(class_num)));
SoftCrossEntropyKernel<T><<<
batch_size, block, block * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
SoftCrossEntropyKernel<T><<<batch_size, block, 0, ctx.stream()>>>(
loss_data, prob_data, label_data, class_num);
} else {
const int64_t* label_data = labels->data<int64_t>();
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/row_conv_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/cuda_device_function.h"
namespace paddle {
namespace operators {
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/cuda_device_function.h"
namespace paddle {
namespace operators {
......@@ -236,12 +236,13 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
sh_topk[tid] = topk[*beam];
}
}
// temporary solution
// NOTE(zcd): temporary solution
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, true);
if (maxid[0] / 32 == warp) {
if (__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength) break;
if (platform::__shfl_sync(mask, *beam, (maxid[0]) % 32, 32) == MaxLength)
break;
}
}
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
namespace paddle {
namespace platform {
// __shfl_down and __shfl have been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
template <typename T>
__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
int width) {
return __shfl(val, src_line, width);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
template <typename T>
__device__ T reduceSum(T val, int tid, int len) {
// NOTE(zcd): The warp size should be taken from the
// parameters of the GPU but not specified as 32 simply.
// To make the reduceSum more efficiently,
// I use Warp-Level Parallelism and assume the Warp size
// is 32 which may be different for different GPU,
// but most card's warp size is 32.
const int warpSize = 32;
__shared__ T shm[warpSize];
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::__shfl_down_sync(mask, val, offset);
if (tid < warpSize) shm[tid] = 0;
if (tid % warpSize == 0) {
shm[tid / warpSize] = val;
}
__syncthreads();
CREATE_SHFL_MASK(mask, tid < warpSize);
if (tid < warpSize) {
val = shm[tid];
for (int offset = warpSize / 2; offset > 0; offset /= 2)
val += platform::__shfl_down_sync(mask, val, offset);
}
return val;
}
} // namespace platform
} // namespace paddle
......@@ -65,26 +65,5 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
return __longlong_as_double(old);
}
#endif
// __shfl_down has been deprecated as of CUDA 9.0.
#if CUDA_VERSION < 9000
template <typename T>
__forceinline__ __device__ T __shfl_down_sync(unsigned, T val, int delta) {
return __shfl_down(val, delta);
}
template <typename T>
__forceinline__ __device__ T __shfl_sync(unsigned, T val, int src_line,
int width) {
return __shfl(val, src_line, width);
}
#define CREATE_SHFL_MASK(mask, predicate) mask = 0u;
#else
#define FULL_WARP_MASK 0xFFFFFFFF
#define CREATE_SHFL_MASK(mask, predicate) \
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
} // namespace platform
} // namespace paddle
......@@ -40,6 +40,7 @@ function print_usage() {
${BLUE}capi${NONE}: generate paddle CAPI package
${BLUE}fluid_inference_lib${NONE}: deploy fluid inference library
${BLUE}check_style${NONE}: run code style check
${BLUE}cicheck${NONE}: run CI tasks
"
}
......@@ -453,6 +454,8 @@ function gen_capi_package() {
}
function gen_fluid_inference_lib() {
mkdir -p ${PADDLE_ROOT}/build
cd ${PADDLE_ROOT}/build
if [ ${WITH_C_API:-OFF} == "OFF" ] ; then
cat <<EOF
========================================
......@@ -503,6 +506,13 @@ function main() {
check_style)
check_style
;;
cicheck)
cmake_gen ${PYTHON_ABI:-""}
build
run_test
gen_capi_package
gen_fluid_inference_lib
;;
*)
print_usage
exit 0
......
......@@ -21,8 +21,7 @@ import executor
from executor import *
import trainer
from trainer import Trainer
from trainer import Event
from trainer import *
import inferencer
from inferencer import Inferencer
......
......@@ -50,8 +50,6 @@ def data(name,
dtype(int|float): The type of data : float32, float_16, int etc
type(VarType): The output type. By default it is LOD_TENSOR.
lod_level(int): The LoD Level. 0 means the input data is not a sequence.
main_program(Program): Name of the main program that calls this
startup_program(Program): Name of the startup program
stop_gradient(bool): A boolean that mentions whether gradient should flow.
Returns:
......@@ -74,13 +72,15 @@ def data(name,
if append_batch_size:
shape = [-1] + shape # append batch size as -1
return helper.create_global_variable(
data_var = helper.create_global_variable(
name=name,
shape=shape,
dtype=dtype,
type=type,
stop_gradient=stop_gradient,
lod_level=lod_level)
data_var.is_data = True
return data_var
class BlockGuardServ(BlockGuard):
......
......@@ -1496,6 +1496,7 @@ def batch_norm(input,
bias_attr=None,
data_layout='NCHW',
in_place=False,
use_mkldnn=False,
name=None,
moving_mean_name=None,
moving_variance_name=None,
......@@ -1574,9 +1575,12 @@ def batch_norm(input,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
},
attrs={"momentum": momentum,
"epsilon": epsilon,
"is_test": is_test})
attrs={
"momentum": momentum,
"epsilon": epsilon,
"is_test": is_test,
"use_mkldnn": use_mkldnn
})
return helper.append_activation(batch_norm_out)
......
......@@ -28,7 +28,8 @@ from contextlib import contextmanager
__all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad',
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'Adadelta', 'ModelAverage'
'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'Adadelta', 'ModelAverage',
'Optimizer'
]
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import numpy as np
import math
import sys
from functools import partial
PASS_NUM = 100
EMBED_SIZE = 32
HIDDEN_SIZE = 256
N = 5
BATCH_SIZE = 32
def create_random_lodtensor(lod, place, low, high):
# The range of data elements is [low, high]
data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64")
res = fluid.LoDTensor()
res.set(data, place)
res.set_lod([lod])
return res
word_dict = paddle.dataset.imikolov.build_dict()
dict_size = len(word_dict)
def inference_network(is_sparse):
first_word = fluid.layers.data(name='firstw', shape=[1], dtype='int64')
second_word = fluid.layers.data(name='secondw', shape=[1], dtype='int64')
third_word = fluid.layers.data(name='thirdw', shape=[1], dtype='int64')
forth_word = fluid.layers.data(name='forthw', shape=[1], dtype='int64')
embed_first = fluid.layers.embedding(
input=first_word,
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=is_sparse,
param_attr='shared_w')
embed_second = fluid.layers.embedding(
input=second_word,
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=is_sparse,
param_attr='shared_w')
embed_third = fluid.layers.embedding(
input=third_word,
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=is_sparse,
param_attr='shared_w')
embed_forth = fluid.layers.embedding(
input=forth_word,
size=[dict_size, EMBED_SIZE],
dtype='float32',
is_sparse=is_sparse,
param_attr='shared_w')
concat_embed = fluid.layers.concat(
input=[embed_first, embed_second, embed_third, embed_forth], axis=1)
hidden1 = fluid.layers.fc(input=concat_embed,
size=HIDDEN_SIZE,
act='sigmoid')
predict_word = fluid.layers.fc(input=hidden1, size=dict_size, act='softmax')
return predict_word
def train_network(is_sparse):
next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64')
predict_word = inference_network(is_sparse)
cost = fluid.layers.cross_entropy(input=predict_word, label=next_word)
avg_cost = fluid.layers.mean(cost)
return avg_cost
def train(use_cuda, is_sparse, save_path):
train_reader = paddle.batch(
paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
def event_handler(event):
print type(event)
if isinstance(event, fluid.EndEpochEvent):
avg_cost = trainer.test(reader=paddle.dataset.imikolov.test(
word_dict, N))
if avg_cost < 5.0:
trainer.params.save(save_path)
return
if math.isnan(avg_cost):
sys.exit("got NaN loss, training failed.")
trainer = fluid.Trainer(
partial(train_network, is_sparse),
fluid.optimizer.SGD(learning_rate=0.001),
place=place)
trainer.train(
reader=train_reader, num_epochs=100, event_handler=event_handler)
def infer(use_cuda, save_path):
params = fluid.Params(save_path)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
inferencer = fluid.Inferencer(inference_network, params, place=place)
lod = [0, 1]
first_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
second_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
third_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
fourth_word = create_random_lodtensor(lod, place, low=0, high=dict_size - 1)
result = inferencer.infer({
'firstw': first_word,
'secondw': second_word,
'thirdw': third_word,
'forthw': fourth_word
})
print(result)
def main(use_cuda, is_sparse):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return
save_path = "word2vec.inference.model"
train(use_cuda, is_sparse, save_path)
infer(use_cuda, save_path)
if __name__ == '__main__':
for use_cuda in (False, True):
for is_sparse in (False, True):
main(use_cuda=use_cuda, is_sparse=is_sparse)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from op_test import OpTest
from paddle.fluid.framework import grad_var_name
from test_batch_norm_op import TestBatchNormOpInference, TestBatchNormOpTraining, _reference_training, _reference_grad
class TestMKLDNNBatchNormOpTraining(TestBatchNormOpTraining):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_formats = ["NCHW"]
def ref_forward_backward(self, x, y_grad, scale, bias, mean, variance,
epsilon, momentum, shape, data_layout):
# run forward
y, saved_mean, saved_variance = _reference_training(
x, scale, bias, epsilon, data_layout)
mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = saved_variance * (1. - momentum) + momentum * variance
# run backward
x_grad, scale_grad, bias_grad = _reference_grad(
x, y_grad, scale, saved_mean, saved_variance, epsilon, data_layout)
return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad
class TestMKLDNNBatchNormOpInference(TestBatchNormOpInference):
def init_kernel_type(self):
self.use_mkldnn = True
def test_check_output(self):
place = core.CPUPlace()
data_format = "NCHW"
self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
if __name__ == '__main__':
unittest.main()
......@@ -158,6 +158,8 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
class TestBatchNormOpInference(unittest.TestCase):
def setUp(self):
self.dtype = np.float32
self.use_mkldnn = False
self.init_kernel_type()
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
......@@ -230,6 +232,7 @@ class TestBatchNormOpInference(unittest.TestCase):
# attrs
is_test=True,
data_layout=data_layout,
use_mkldnn=self.use_mkldnn,
epsilon=epsilon)
batch_norm_op.run(scope, place)
......@@ -254,10 +257,15 @@ class TestBatchNormOpInference(unittest.TestCase):
[2, 3, 4, 5])
self.check_with_place(place, data_format, self.dtype, [2, 3])
def init_kernel_type(self):
pass
class TestFP16BatchNormOpInference(TestBatchNormOpInference):
def setUp(self):
self.dtype = np.float16
self.use_mkldnn = False
self.init_kernel_type()
def test_check_output(self):
places = []
......@@ -274,9 +282,28 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
class TestBatchNormOpTraining(unittest.TestCase):
def setUp(self):
self.use_mkldnn = False
self.data_formats = ["NCHW", "NHWC"]
self.init_kernel_type()
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
np.allclose(np.array(tensor), np_array, atol=atol)
def ref_forward_backward(self, x, y_grad, scale, bias, mean, variance,
epsilon, momentum, shape, data_layout):
# run forward
y, saved_mean, var_ref = _reference_training(x, scale, bias, epsilon,
data_layout)
mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = var_ref * (1. - momentum) + momentum * variance
saved_variance = 1. / np.sqrt(var_ref + epsilon)
# run backward
x_grad, scale_grad, bias_grad = _reference_grad(
x, y_grad, scale, saved_mean, var_ref, epsilon, data_layout)
return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad
def test_forward_backward(self):
def test_with_place(place, data_layout, shape):
# attr
......@@ -295,16 +322,11 @@ class TestBatchNormOpTraining(unittest.TestCase):
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
# run forward
y, saved_mean, var_ref = _reference_training(x, scale, bias,
epsilon, data_layout)
mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = var_ref * (1. - momentum) + momentum * variance
saved_variance = 1. / np.sqrt(var_ref + epsilon)
# run backward
y_grad = np.random.random_sample(shape).astype(np.float32)
x_grad, scale_grad, bias_grad = _reference_grad(
x, y_grad, scale, saved_mean, var_ref, epsilon, data_layout)
y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad = self.ref_forward_backward(
x, y_grad, scale, bias, mean, variance, epsilon, momentum,
shape, data_layout)
var_dict = locals()
var_dict['y@GRAD'] = y_grad
......@@ -344,7 +366,8 @@ class TestBatchNormOpTraining(unittest.TestCase):
"momentum": momentum,
"epsilon": epsilon,
"is_test": False,
"data_layout": data_layout
"data_layout": data_layout,
"use_mkldnn": self.use_mkldnn
})
block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
......@@ -387,13 +410,17 @@ class TestBatchNormOpTraining(unittest.TestCase):
print "op test forward passed: ", str(place), data_layout
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
places.append(core.CUDAPlace(0))
for place in places:
for data_format in ["NCHW", "NHWC"]:
for data_format in self.data_formats:
test_with_place(place, data_format, [2, 3, 4, 5])
def init_kernel_type(self):
pass
if __name__ == '__main__':
unittest.main()
......@@ -12,44 +12,200 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import core
import framework
import executor
import data_feeder
import contextlib
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
import optimizer as opt_module
__all__ = [
'Event',
'Trainer',
'BeginEpochEvent',
'EndEpochEvent',
'BeginStepEvent',
'EndStepEvent',
]
class Event(object):
BEGIN_EPOCH = 0
END_EPOCH = 1
BEGIN_STEP = 2
END_STEP = 3
class BeginEpochEvent(object):
def __init__(self, epoch_id):
self.epoch = epoch_id
class EndEpochEvent(object):
def __init__(self, epoch_id):
self.epoch = epoch_id
def __init__(self):
self.step = 0
self.epoch = 0
self.type = Event.BEGIN_EPOCH
class BeginStepEvent(object):
def __init__(self, epoch_id, step_id):
self.epoch = epoch_id
self.step = step_id
class EndStepEvent(object):
def __init__(self, epoch_id, step_id):
self.epoch = epoch_id
self.step = step_id
class Trainer(object):
"""
Args:
network_func(callable): A function which will return loss. The loss must be a scaler.
optimizer(optimizer.Optimizer): The optimizer should be an instance of Optimizer
params:
place: The device place of this trainer.
"""
def __init__(self, network_func, optimizer, params=None, place=None):
# 1. we need to generate a framework.Program by calling
# network_func. Reference: fluid.program_guard in
# test_word2vec.py
self.scope = self._get_scope_from_params(params)
self.startup_program = framework.Program()
self.train_program = framework.Program()
with framework.program_guard(self.train_program, self.startup_program):
loss = network_func()
if not isinstance(optimizer, opt_module.Optimizer):
raise TypeError(
"The optimizer should be an instance of Optimizer")
optimizer.minimize(loss)
self.place = Trainer._check_and_get_place(place)
# 2. move the default_main_program to self.program and run the
# default_startup program on an empty core.Scope()
# Run startup program
if params is None:
exe = executor.Executor(place)
exe.run(self.startup_program, scope=self.scope)
# 3. call self.params.add_vars with the initialized scope, it
# will add the new vars of the initialized scope into
# self.params.
self.network_func = network_func
self.optimizer = optimizer
self.params = params
self.place = place
# TODO(yuyang): This depends on parameters implementation.
# TODO(helin): support distributed training
def train(self, reader, num_epochs, event_handler):
pass
def train(self,
num_epochs,
event_handler,
reader=None,
parallel=False,
feed_order=None):
"""
Train the model.
Args:
num_epochs: The number of epoch. An epoch will process all data in reader
event_handler: The event handler. A function with type (ev:Event)->void
reader:
parallel: True if use multi-CPUs or multi-GPUs
feed_order: Feeding order of reader. None will following the defining
order in program
Returns:
"""
if parallel:
raise NotImplementedError(
"Parallel Executor version of trainer is not implemented")
self._train_by_executor(num_epochs, event_handler, reader, feed_order)
def test(self, reader):
pass
def _get_scope_from_params(self, params):
"""
Get Scope from parameter object.
Args:
params(Parameter|None): The parameter object instance. Could be None.
Returns: New scope if params is None. Or params.scope()
NOTE: This method is WIP. Not fully implemented.
"""
if params is None:
return core.Scope() # new scope when params is None
else:
raise NotImplementedError("Not implemented right now.")
@staticmethod
def _check_and_get_place(place):
"""
Check the type of place or get the default place
Args:
place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
Raises:
TypeError if the type mismatched.
Returns:
the original place if it is not None.
if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
Otherwise returns CPUPlace by default.
"""
if place is None:
if core.is_compiled_with_cuda():
return core.CUDAPlace(0)
else:
return core.CPUPlace()
else:
if not isinstance(place, core.CUDAPlace) and not isinstance(
place, core.CPUPlace):
raise TypeError("Place should be either CUDAPlace or CPUPlace")
return place
@contextlib.contextmanager
def _prog_and_scope_guard(self):
with framework.program_guard(
main_program=self.train_program,
startup_program=self.startup_program):
with executor.scope_guard(self.scope):
yield
def _train_by_executor(self, num_epochs, event_handler, reader, feed_order):
"""
Train by Executor and single device.
Args:
num_epochs:
event_handler:
reader:
feed_order:
Returns:
"""
with self._prog_and_scope_guard():
exe = executor.Executor(self.place)
if feed_order is None:
feed_var_list = [
var
for var in self.train_program.global_block(
).vars.itervalues()
if hasattr(var, 'is_data') and var.is_data
]
else:
feed_var_list = [
self.train_program.global_block().var(var_name)
for var_name in feed_order
]
feeder = data_feeder.DataFeeder(
feed_list=feed_var_list, place=self.place)
for epoch_id in range(num_epochs):
event_handler(BeginEpochEvent(epoch_id))
for step_id, data in enumerate(reader()):
event_handler(BeginStepEvent(epoch_id, step_id))
exe.run(feed=feeder.feed(data), fetch_list=[])
event_handler(EndStepEvent(epoch_id, step_id))
event_handler(EndEpochEvent(epoch_id))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册