Unverified commit 767050d9, authored by Yiqun Liu, committed by GitHub

Implement the grad and enhance the cache of norm_convolution fusion ops. (#36168)

Parent: b3d2dc7b
@@ -15,8 +15,10 @@ limitations under the License. */
 #pragma once
 #include <algorithm>
+#include <mutex>
 #include <unordered_map>
 #include <vector>
+#include "glog/logging.h"
 namespace paddle {
 namespace framework {
...
@@ -14,10 +14,8 @@ limitations under the License. */
 #pragma once
-#include <string>
 #include <vector>
-#include "paddle/fluid/platform/cudnn_desc.h"
-#include "paddle/fluid/platform/cudnn_helper.h"
+#include "paddle/fluid/framework/operator_kernel_configs.h"
 #include "paddle/fluid/platform/dynload/cudnn.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -41,12 +39,9 @@ class CudnnFusionOp {
   }
   ~CudnnFusionOp() {
-    // New 'fused op' descriptor destruction
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_));
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_));
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroyFusedOpsPlan(op_));
+    dynload::cudnnDestroyFusedOpsVariantParamPack(op_variant_params_);
+    dynload::cudnnDestroyFusedOpsConstParamPack(op_const_params_);
+    dynload::cudnnDestroyFusedOpsPlan(op_);
   }
   // Execute fused op
@@ -121,41 +116,49 @@ class CudnnFusionOp {
   // Get the workspace, which is required before Execute().
   size_t GetWorkspaceSizeInBytes(cudnnHandle_t cudnn_handle) {
-    size_t workspace_bytes = 0U;
-    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnMakeFusedOpsPlan(
-        cudnn_handle, op_, op_const_params_, &workspace_bytes));
-    plan_created_ = true;
-    return workspace_bytes;
+    if (!plan_created_) {
+      workspace_bytes_ = 0U;
+      PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnMakeFusedOpsPlan(
+          cudnn_handle, op_, op_const_params_, &workspace_bytes_));
+      plan_created_ = true;
+    }
+    return workspace_bytes_;
   }
  private:
   bool plan_created_;
+  size_t workspace_bytes_;
   cudnnFusedOpsPlan_t op_;
   cudnnFusedOpsConstParamPack_t op_const_params_;
   cudnnFusedOpsVariantParamPack_t op_variant_params_;
 };
-static inline std::vector<int> GetStrides(const std::vector<int> &shape) {
-  if (shape.size() < 1) {
-    return {};
-  }
-  int dim = static_cast<int>(shape.size());
-  std::vector<int> pro_shape(shape);
-  std::vector<int> strides(dim);
-  int temp = pro_shape[1];
-  pro_shape.erase(pro_shape.begin() + 1);
-  pro_shape.push_back(temp);
-  strides.back() = 1;
-  for (int i = dim - 2; i >= 0; --i) {
-    strides[i] = strides[i + 1] * pro_shape[i + 1];
-  }
-  strides.pop_back();
-  strides.insert(strides.begin() + 1, 1);
-  return strides;
-}
-static inline int64_t AlignUp(int64_t a, int64_t b) { return (a + b - 1) / b; }
+class CudnnFusionOpCache {
+ public:
+  static CudnnFusionOpCache &Instance() {
+    static CudnnFusionOpCache instance;
+    return instance;
+  }
+  framework::AlgorithmsCache<CudnnFusionOp *> *GetForward() {
+    return &forward_cache_;
+  }
+  framework::AlgorithmsCache<CudnnFusionOp *> *GetBackward() {
+    return &backward_cache_;
+  }
+ private:
+  CudnnFusionOpCache() {}
+  ~CudnnFusionOpCache() {
+    // Need to delete the memory of cache.
+  }
+  CudnnFusionOpCache(const CudnnFusionOpCache &) {}
+ private:
+  framework::AlgorithmsCache<CudnnFusionOp *> forward_cache_;
+  framework::AlgorithmsCache<CudnnFusionOp *> backward_cache_;
+};
 #endif  // CUDNN_VERSION >= 8000
 }  // namespace operators
...
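For orientation, the cache introduced above is keyed by shapes and attributes through framework::AlgorithmsCache, and callers fetch (or lazily create) a configured CudnnFusionOp through a creation lambda. A minimal sketch of that pattern, mirroring the GetForwardOp helper added in the norm-conv header below; the shape and dtype values here are illustrative placeholders, not part of this commit:

  // Sketch only: assumes the AlgorithmsCache<T>::GetAlgorithm overload used in
  // this commit (dims, strides, paddings, dilations, algorithmFlags, dtype,
  // creation lambda). The concrete values below are hypothetical.
  framework::AlgorithmsCache<CudnnFusionOp *> &cache =
      *(CudnnFusionOpCache::Instance().GetForward());
  std::vector<int64_t> in_dims = {4, 56, 56, 32};     // NHWC input dims (example)
  std::vector<int64_t> filter_dims = {32, 1, 1, 32};  // NHWC filter dims (example)
  std::vector<int> strides = {1, 1}, paddings = {0, 0}, dilations = {1, 1};
  CudnnFusionOp *fwd_op = cache.GetAlgorithm(
      in_dims, filter_dims, strides, paddings, dilations, 0,
      static_cast<int64_t>(CUDNN_DATA_HALF), [&]() {
        // Create and configure the fused op once; later calls with the same
        // key reuse the cached pointer instead of rebuilding the cudnn plan.
        return new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS);
      });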
@@ -15,125 +15,320 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
+#include "paddle/fluid/platform/cudnn_desc.h"
+#include "paddle/fluid/platform/cudnn_helper.h"
 namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;
 namespace dynload = platform::dynload;
+template <typename T>
+using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
 #if CUDNN_VERSION >= 8000
+static size_t RoundUp(int64_t a, int64_t b) { return (a + b - 1) / b * b; }
 template <typename T>
-class CudnnNormConvolutionOp {
- public:
-  CudnnNormConvolutionOp()
-      : fwd_op_(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS) {}
-  ~CudnnNormConvolutionOp() {}
-  void Init(const platform::CUDADeviceContext &ctx,
-            const std::vector<int> &input_shape,
-            const std::vector<int> &filter_shape,
-            const std::vector<int> &output_shape, const int &pad,
-            const int &stride, const int &dilate, const int &group) {
-    cudnn_fwd_compute_type_ = platform::CudnnDataType<float>::type;
-    dtype_ = platform::CudnnDataType<T>::type;
-    format_ = CUDNN_TENSOR_NHWC;
-    InitDescriptors(ctx, input_shape, filter_shape, output_shape, pad, stride,
-                    dilate, group);
-    GetWorkspaceSize(ctx);
-  }
-  void Forward(const platform::CUDADeviceContext &ctx, T *input_ptr,
-               T *filter_ptr, T *output_ptr, float *sum_ptr,
-               float *sum_of_squares_ptr) {
-    auto handle = ctx.cudnn_handle();
-    auto workspace_handle = ctx.cudnn_workspace_handle();
-    // Set variant_param
-    // input ptr
-    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
-    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr);
-    fwd_op_.SetOpVariantParamAttrPtr(
-        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &fwd_workspace_byte_);
-    // output ptr
-    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr);
-    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
-    fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);
-    workspace_handle.RunFunc(
-        [&](void *workspace_ptr) {
-          // workspace ptr
-          fwd_op_.SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
-          // fused op execute
-          fwd_op_.Execute(handle);
-        },
-        fwd_workspace_byte_);
-  }
-  // TBD
-  void Backward(const platform::CUDADeviceContext &ctx) {}
- private:
-  void InitDescriptors(const platform::CUDADeviceContext &ctx,
-                       const std::vector<int> &input_shape,
-                       const std::vector<int> &filter_shape,
-                       const std::vector<int> &output_shape, const int &pad,
-                       const int &stride, const int &dilate, const int &group) {
-    // Set constant_param
-    fwd_op_.SetOpConstParamAttr(
-        {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_WDATA_PLACEHOLDER,
-         CUDNN_PARAM_YDATA_PLACEHOLDER},
-        CUDNN_PTR_16B_ALIGNED);
-    fwd_op_.SetOpConstParamAttr(
-        {CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER},
-        CUDNN_PTR_16B_ALIGNED);
-    std::vector<int> pad_vec = {pad, pad};
-    std::vector<int> stride_vec = {stride, stride};
-    std::vector<int> dilate_vec = {dilate, dilate};
-    int output_channel = filter_shape[0];
-    std::vector<int> stats_shape = {1, 1, 1, output_channel};
-    // set conv desc
-    conv_desc_.set(dtype_, pad_vec, stride_vec, dilate_vec, false, group);
-    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC, conv_desc_.desc());
-    // set input desc
-    in_desc_.set(input_shape, format_, dtype_);
-    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_XDESC, in_desc_.desc());
-    // set filter desc
-    filter_desc_.set(filter_shape, format_, dtype_, group);
-    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_WDESC, filter_desc_.desc());
-    // set output desc
-    out_desc_.set(output_shape, format_, dtype_);
-    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_YDESC, out_desc_.desc());
-    // set output_stats desc
-    out_stats_desc_.set(stats_shape, format_, cudnn_fwd_compute_type_);
-    fwd_op_.SetOpConstParamDesc(CUDNN_PARAM_YSTATS_DESC,
-                                out_stats_desc_.desc());
-    fwd_op_.SetOpConstParamAttr(CUDNN_PARAM_BN_MODE, CUDNN_BATCHNORM_SPATIAL);
-  }
-  void GetWorkspaceSize(const platform::CUDADeviceContext &ctx) {
-    auto handle = ctx.cudnn_handle();
-    fwd_workspace_byte_ = fwd_op_.GetWorkspaceSizeInBytes(handle);
-  }
-  size_t fwd_workspace_byte_ = 0;
-  cudnnDataType_t dtype_;
-  cudnnDataType_t cudnn_fwd_compute_type_;
-  platform::TensorDescriptor in_desc_;
-  platform::FilterDescriptor filter_desc_;
-  platform::TensorDescriptor out_desc_;
-  platform::TensorDescriptor out_stats_desc_;
-  platform::ConvolutionDescriptor conv_desc_;
-  cudnnTensorFormat_t format_;
-  CudnnFusionOp fwd_op_;
-};
+struct NormConvolutionArgs {
+  NormConvolutionArgs() {
+    dtype = platform::CudnnDataType<T>::type;
+    format = CUDNN_TENSOR_NHWC;
+    compute_type = platform::CudnnDataType<float>::type;
+  }
+  void Set(const std::vector<int> &input_shape,
+           const std::vector<int> &filter_shape,
+           const std::vector<int> &output_shape, int padding, int stride,
+           int dilation, int group) {
+    PADDLE_ENFORCE_EQ(
+        input_shape.size(), 4U,
+        platform::errors::InvalidArgument(
+            "The size of input_shape is expected to 4. But recieved "
+            "input_shape's size is %d, input_shape is [%s].",
+            input_shape.size(), framework::make_ddim(input_shape)));
+    PADDLE_ENFORCE_EQ(
+        filter_shape.size(), 4U,
+        platform::errors::InvalidArgument(
+            "The size of filter_shape is expected to 4. But recieved "
+            "filter_shape's size is %d, filter_shape is [%s].",
+            filter_shape.size(), framework::make_ddim(filter_shape)));
+    PADDLE_ENFORCE_EQ(filter_shape[1] == filter_shape[2] &&
+                          (filter_shape[1] == 1 || filter_shape[1] == 3),
+                      true,
+                      platform::errors::InvalidArgument(
+                          "The filter_shape is expected to store as nhwc, and "
+                          "h = w = 1 or 3. But recieved filter_shape is [%s].",
+                          framework::make_ddim(filter_shape)));
+    PADDLE_ENFORCE_EQ(
+        output_shape.size(), 4U,
+        platform::errors::InvalidArgument(
+            "The size of output_shape is expected to 4. But recieved "
+            "filter_shape's size is %d, filter_shape is [%s].",
+            output_shape.size(), framework::make_ddim(output_shape)));
+    for (size_t i = 0; i < input_shape.size(); ++i) {
+      in_dims.push_back(input_shape[i]);
+    }
+    for (size_t i = 0; i < filter_shape.size(); ++i) {
+      filter_dims.push_back(filter_shape[i]);
+    }
+    paddings = {padding, padding};
+    strides = {stride, stride};
+    dilations = {dilation, dilation};
+    in_desc.set(input_shape, format, dtype);
+    filter_desc.set(filter_shape, format, dtype, group);
+    out_desc.set(output_shape, format, dtype);
+    int output_channel = filter_shape[0];
+    std::vector<int> stats_shape = {1, 1, 1, output_channel};
+    out_stats_desc.set(stats_shape, format, compute_type);
+    conv_desc.set(dtype, paddings, strides, dilations, false, group);
+  }
+  cudnnDataType_t dtype;
+  cudnnTensorFormat_t format;
+  cudnnDataType_t compute_type;
+  std::vector<int64_t> in_dims;
+  std::vector<int64_t> filter_dims;
+  std::vector<int> strides;
+  std::vector<int> paddings;
+  std::vector<int> dilations;
+  platform::TensorDescriptor in_desc;
+  platform::FilterDescriptor filter_desc;
+  platform::TensorDescriptor out_desc;
+  platform::TensorDescriptor out_stats_desc;
+  platform::ConvolutionDescriptor conv_desc;
+};
+template <typename T>
+class CudnnNormConvolution {
+ public:
+  CudnnNormConvolution(const platform::CUDADeviceContext &ctx,
+                       const std::vector<int> &input_shape,
+                       const std::vector<int> &filter_shape,
+                       const std::vector<int> &output_shape, const int &padding,
+                       const int &stride, const int &dilation,
+                       const int &group) {
+    args_.Set(input_shape, filter_shape, output_shape, padding, stride,
+              dilation, group);
+  }
+  ~CudnnNormConvolution() {}
+  void Forward(const platform::CUDADeviceContext &ctx, T *input_ptr,
+               T *filter_ptr, T *output_ptr, float *sum_ptr,
+               float *sum_of_squares_ptr) {
+    auto cudnn_handle = ctx.cudnn_handle();
+    CudnnFusionOp *fwd_op = GetForwardOp(ctx);
+    size_t workspace_size = RoundUp(
+        static_cast<int64_t>(fwd_op->GetWorkspaceSizeInBytes(cudnn_handle)),
+        512);
+    // Set variant_param
+    // input ptr
+    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
+    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr);
+    fwd_op->SetOpVariantParamAttrPtr(
+        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);
+    // output ptr
+    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr);
+    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
+    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);
+    ctx.cudnn_workspace_handle().RunFunc(
+        [&](void *workspace_ptr) {
+          // workspace ptr
+          fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
+          // fused op execute
+          fwd_op->Execute(cudnn_handle);
+        },
+        workspace_size);
+  }
+ private:
+  CudnnFusionOp *GetForwardOp(const platform::CUDADeviceContext &ctx) {
+    framework::AlgorithmsCache<CudnnFusionOp *> &cache =
+        *(CudnnFusionOpCache::Instance().GetForward());
+    CudnnFusionOp *fwd_op = cache.GetAlgorithm(
+        args_.in_dims, args_.filter_dims, args_.strides, args_.paddings,
+        args_.dilations, 0, static_cast<int64_t>(args_.dtype), [&]() {
+          CudnnFusionOp *fwd_op =
+              new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS);
+          // Set constant_param
+          fwd_op->SetOpConstParamAttr(
+              {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_WDATA_PLACEHOLDER,
+               CUDNN_PARAM_YDATA_PLACEHOLDER},
+              CUDNN_PTR_16B_ALIGNED);
+          fwd_op->SetOpConstParamAttr(
+              {CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER},
+              CUDNN_PTR_16B_ALIGNED);
+          // conv desc
+          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC,
+                                      args_.conv_desc.desc());
+          // input desc
+          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
+          // filter desc
+          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_WDESC,
+                                      args_.filter_desc.desc());
+          // output desc
+          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YDESC, args_.out_desc.desc());
+          // output_stats desc
+          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YSTATS_DESC,
+                                      args_.out_stats_desc.desc());
+          // batch_norm mode
+          fwd_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
+                                      CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
+          // Make cudnn fused ops plan
+          fwd_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle());
+          return fwd_op;
+        });
+    return fwd_op;
+  }
+ private:
+  NormConvolutionArgs<T> args_;
+};
+template <typename T>
+class CudnnNormConvolutionGrad {
+ public:
+  CudnnNormConvolutionGrad(const platform::CUDADeviceContext &ctx,
+                           const std::vector<int> &input_shape,
+                           const std::vector<int> &filter_shape,
+                           const std::vector<int> &output_shape,
+                           const int &padding, const int &stride,
+                           const int &dilation, const int &group) {
+    args_.Set(input_shape, filter_shape, output_shape, padding, stride,
+              dilation, group);
+    dgrad_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
+  }
+  ~CudnnNormConvolutionGrad() {}
+  void Backward(const platform::CUDADeviceContext &ctx, T *input_ptr,
+                T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr,
+                T *filter_grad_ptr, bool use_addto = false) {
+    if (filter_grad_ptr) {
+      BackwardFilter(ctx, input_ptr, output_grad_ptr, filter_ptr,
+                     filter_grad_ptr);
+    }
+    if (input_grad_ptr) {
+      BackwardData(ctx, input_ptr, output_grad_ptr, filter_ptr, input_grad_ptr,
+                   use_addto);
+    }
+  }
+ private:
+  void BackwardFilter(const platform::CUDADeviceContext &ctx, T *input_ptr,
+                      T *output_grad_ptr, T *filter_ptr, T *filter_grad_ptr) {
+    auto cudnn_handle = ctx.cudnn_handle();
+    CudnnFusionOp *wgrad_op = GetBackwardFilterOp(ctx);
+    size_t workspace_size = RoundUp(
+        static_cast<int64_t>(wgrad_op->GetWorkspaceSizeInBytes(cudnn_handle)),
+        512);
+    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
+    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, output_grad_ptr);
+    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DWDATA, filter_grad_ptr);
+    wgrad_op->SetOpVariantParamAttrPtr(
+        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);
+    ctx.cudnn_workspace_handle().RunFunc(
+        [&](void *workspace_ptr) {
+          // workspace ptr
+          wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
+                                             workspace_ptr);
+          // fused op execute
+          wgrad_op->Execute(cudnn_handle);
+        },
+        workspace_size);
+  }
+  void BackwardData(const platform::CUDADeviceContext &ctx, T *input_ptr,
+                    T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr,
+                    bool use_addto = false) {
+    auto cudnn_handle = ctx.cudnn_handle();
+    size_t workspace_size = GetWorkspaceSizeBwdData(ctx);
+    // Convolution dgrad followed optionally by batchnorm dgrad
+    ScalingParamType<T> alpha = 1.0f;
+    ScalingParamType<T> beta = use_addto ? 1.0f : 0.0f;
+    ctx.cudnn_workspace_handle().RunFunc(
+        [&](void *cudnn_workspace_ptr) {
+          PADDLE_ENFORCE_CUDA_SUCCESS(
+              platform::dynload::cudnnConvolutionBackwardData(
+                  cudnn_handle, &alpha, args_.filter_desc.desc(), filter_ptr,
+                  args_.out_desc.desc(), output_grad_ptr,
+                  args_.conv_desc.desc(), dgrad_algo_, cudnn_workspace_ptr,
+                  workspace_size, &beta, args_.in_desc.desc(), input_grad_ptr));
+        },
+        workspace_size);
+  }
+  CudnnFusionOp *GetBackwardFilterOp(const platform::CUDADeviceContext &ctx) {
+    framework::AlgorithmsCache<CudnnFusionOp *> &cache =
+        *(CudnnFusionOpCache::Instance().GetBackward());
+    CudnnFusionOp *wgrad_op = cache.GetAlgorithm(
+        args_.in_dims, args_.filter_dims, args_.strides, args_.paddings,
+        args_.dilations, 0, static_cast<int64_t>(args_.dtype), [&]() {
+          CudnnFusionOp *wgrad_op =
+              new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD);
+          wgrad_op->SetOpConstParamAttr(
+              {CUDNN_PARAM_DYDATA_PLACEHOLDER, CUDNN_PARAM_XDATA_PLACEHOLDER,
+               CUDNN_PARAM_DWDATA_PLACEHOLDER},
+              CUDNN_PTR_16B_ALIGNED);
+          // conv desc
+          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC,
+                                        args_.conv_desc.desc());
+          // input desc
+          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC,
+                                        args_.in_desc.desc());
+          // filter desc
+          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DWDESC,
+                                        args_.filter_desc.desc());
+          // output desc
+          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DYDESC,
+                                        args_.out_desc.desc());
+          wgrad_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
+                                        CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
+          // Make cudnn fused ops plan
+          wgrad_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle());
+          return wgrad_op;
+        });
+    return wgrad_op;
+  }
+  size_t GetWorkspaceSizeBwdData(const platform::CUDADeviceContext &ctx) {
+    size_t workspace_size = 0U;
+    auto handle = ctx.cudnn_handle();
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
+            handle, args_.filter_desc.desc(), args_.out_desc.desc(),
+            args_.conv_desc.desc(), args_.in_desc.desc(), dgrad_algo_,
+            &workspace_size));
+    return RoundUp(workspace_size, 512);
+  }
+ private:
+  NormConvolutionArgs<T> args_;
+  cudnnConvolutionBwdDataAlgo_t dgrad_algo_;
+};
 #endif
 }  // namespace operators
 }  // namespace paddle
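Taken together, the forward and grad helpers above are driven the way the test below does it: construct the op from NHWC shapes, then pass raw device pointers. A condensed usage sketch, with illustrative shapes and with pointer variables (input_ptr, filter_ptr, and so on) assumed to already reference NHWC device buffers of matching size:

  // Sketch only: ctx is a platform::CUDADeviceContext; T is the element type
  // (paddle::platform::float16 in the tests below).
  std::vector<int> input_shape = {4, 56, 56, 32};   // NHWC (example)
  std::vector<int> filter_shape = {32, 3, 3, 32};   // [O, H, W, I], h = w = 3 (example)
  std::vector<int> output_shape = {4, 56, 56, 32};
  int padding = 1, stride = 1, dilation = 1, group = 1;

  op::CudnnNormConvolution<T> conv_op(ctx, input_shape, filter_shape,
                                      output_shape, padding, stride, dilation,
                                      group);
  // Forward also produces the per-channel sum and sum-of-squares consumed by
  // the following batch-norm stage.
  conv_op.Forward(ctx, input_ptr, filter_ptr, output_ptr, sum_ptr,
                  sum_of_squares_ptr);

  op::CudnnNormConvolutionGrad<T> conv_grad_op(ctx, input_shape, filter_shape,
                                               output_shape, padding, stride,
                                               dilation, group);
  // Either grad pointer may be null to skip that branch: the filter grad uses
  // the fused CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD op, while the data grad
  // falls back to cudnnConvolutionBackwardData.
  conv_grad_op.Backward(ctx, input_ptr, output_grad_ptr, filter_ptr,
                        input_grad_ptr, filter_grad_ptr);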
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include <random>
 #include <vector>
@@ -29,23 +30,80 @@ namespace op = paddle::operators;
 using Tensor = paddle::framework::Tensor;
 USE_OP(conv2d);
+USE_OP(conv2d_grad);
 USE_OP_DEVICE_KERNEL(conv2d, CUDNN);
+USE_OP_DEVICE_KERNEL(conv2d_grad, CUDNN);
+template <typename T>
+void InitRandomTensor(const std::vector<int64_t> &dims,
+                      framework::Tensor *cpu_out) {
+  T *cpu_out_ptr = cpu_out->mutable_data<T>(framework::make_ddim(dims),
+                                            platform::CPUPlace());
+  std::default_random_engine random(0);
+  std::uniform_real_distribution<float> dis(0.0, 1.0);
+  for (int i = 0; i < cpu_out->numel(); ++i) {
+    cpu_out_ptr[i] = static_cast<T>(dis(random));
+  }
+}
+template <typename T>
+void TransposeNchwToNhwc(const framework::Tensor &cpu_in,
+                         framework::Tensor *cpu_out) {
+  auto in_dims = cpu_in.dims();
+  EXPECT_EQ(cpu_in.dims().size(), 4);
+  const T *cpu_in_ptr = cpu_in.data<T>();
+  T *cpu_out_ptr = cpu_out->mutable_data<T>(
+      {in_dims[0], in_dims[2], in_dims[3], in_dims[1]}, platform::CPUPlace());
+  int64_t n = in_dims[0];
+  int64_t c = in_dims[1];
+  int64_t hw = in_dims[2] * in_dims[3];
+  for (int i = 0; i < n; ++i) {
+    for (int j = 0; j < hw; ++j) {
+      for (int k = 0; k < c; ++k) {
+        int dst_idx = i * hw * c + j * c + k;
+        int src_idx = i * c * hw + k * hw + j;
+        cpu_out_ptr[dst_idx] = cpu_in_ptr[src_idx];
+      }
+    }
+  }
+}
+template <typename T>
+void CheckOutput(const framework::Tensor &cpu_res,
+                 const framework::Tensor &cpu_base, float diff,
+                 bool is_relative_atol = false) {
+  EXPECT_EQ(cpu_res.dims(), cpu_base.dims());
+  const T *cpu_res_ptr = cpu_res.data<T>();
+  const T *cpu_base_ptr = cpu_base.data<T>();
+  for (int i = 0; i < cpu_res.numel(); ++i) {
+    if (is_relative_atol) {
+      EXPECT_LT(static_cast<float>(std::abs((cpu_res_ptr[i] - cpu_base_ptr[i]) /
+                                            cpu_base_ptr[i])),
+                diff);
+    } else {
+      EXPECT_LT(static_cast<float>(std::abs(cpu_res_ptr[i] - cpu_base_ptr[i])),
+                diff);
+    }
+  }
+}
-// get paddle conv2d op results as baseline
+// Use Paddle conv2d op results as baseline
 template <typename T>
-void Conv2DForwardCompute(const Tensor &x, const Tensor &w, Tensor *y,
-                          const platform::CUDADeviceContext &ctx) {
+void ComputeConv2DForward(const platform::CUDADeviceContext &ctx,
+                          const Tensor &cpu_input, const Tensor &cpu_filter,
+                          Tensor *cpu_output) {
   framework::Scope scope;
-  auto var_x = scope.Var("Input");
-  auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
-  auto var_w = scope.Var("Filter");
-  auto tensor_w = var_w->GetMutable<framework::LoDTensor>();
-  auto var_y = scope.Var("Output");
-  auto tensor_y = var_y->GetMutable<framework::LoDTensor>();
+  auto *input = scope.Var("Input")->GetMutable<framework::LoDTensor>();
+  auto *filter = scope.Var("Filter")->GetMutable<framework::LoDTensor>();
+  auto *output = scope.Var("Output")->GetMutable<framework::LoDTensor>();
   auto place = ctx.GetPlace();
-  TensorCopySync(x, place, tensor_x);
-  TensorCopySync(w, place, tensor_w);
+  TensorCopySync(cpu_input, place, input);
+  TensorCopySync(cpu_filter, place, filter);
   framework::AttributeMap attrs;
   bool use_cudnn = true;
@@ -60,25 +118,94 @@ void Conv2DForwardCompute(const Tensor &x, const Tensor &w, Tensor *y,
                                            {{"Output", {"Output"}}}, attrs);
   op->Run(scope, ctx.GetPlace());
-  TensorCopySync(*tensor_y, place, y);
-  ctx.Wait();
+  TensorCopySync(*output, platform::CPUPlace(), cpu_output);
 }
+// Use Paddle conv2d_grad op results as baseline
 template <typename T>
-class TestCudnnNormConvOpForward {
- public:
-  TestCudnnNormConvOpForward() {
-    batch_size_ = 2;
-    height_ = 8;
-    width_ = 8;
-    input_channels_ = 8;
-    output_channels_ = 32;
-    kernel_size_ = 1;
-    stride_ = 1;
-    pad_ = 0;
-  }
-  TestCudnnNormConvOpForward(int batch_size, int height, int width,
+void ComputeConv2DBackward(const platform::CUDADeviceContext &ctx,
+                           const Tensor &cpu_input, const Tensor &cpu_filter,
+                           const Tensor &cpu_output_grad,
+                           framework::Tensor *cpu_input_grad,
+                           framework::Tensor *cpu_filter_grad, int stride,
+                           int padding, int dilation) {
+  framework::Scope scope;
+  auto *input = scope.Var("Input")->GetMutable<framework::LoDTensor>();
+  auto *filter = scope.Var("Filter")->GetMutable<framework::LoDTensor>();
+  auto *output_grad =
+      scope.Var("Output@GRAD")->GetMutable<framework::LoDTensor>();
+  auto *input_grad =
+      scope.Var("Input@GRAD")->GetMutable<framework::LoDTensor>();
+  auto *filter_grad =
+      scope.Var("Filter@GRAD")->GetMutable<framework::LoDTensor>();
+  auto place = ctx.GetPlace();
+  TensorCopySync(cpu_input, place, input);
+  TensorCopySync(cpu_filter, place, filter);
+  TensorCopySync(cpu_output_grad, place, output_grad);
+  framework::AttributeMap attrs;
+  bool use_cudnn = true;
+  std::string data_format = "NHWC";
+  std::string padding_algorithm = "SAME";
+  std::vector<int> strides = {stride, stride};
+  std::vector<int> paddings = {padding, padding};
+  std::vector<int> dilations = {dilation, dilation};
+  int groups = 1;
+  bool exhaustive_search = false;
+  bool use_addto = false;
+  attrs.insert({"use_cudnn", use_cudnn});
+  attrs.insert({"data_format", data_format});
+  attrs.insert({"padding_algorithm", padding_algorithm});
+  attrs.insert({"strides", strides});
+  attrs.insert({"paddings", paddings});
+  attrs.insert({"dilations", dilations});
+  attrs.insert({"groups", groups});
+  attrs.insert({"exhaustive_search", exhaustive_search});
+  attrs.insert({"use_addto", use_addto});
+  auto op = framework::OpRegistry::CreateOp(
+      "conv2d_grad", {{"Input", {"Input"}},
+                      {"Filter", {"Filter"}},
+                      {"Output@GRAD", {"Output@GRAD"}}},
+      {{"Input@GRAD", {"Input@GRAD"}}, {"Filter@GRAD", {"Filter@GRAD"}}},
+      attrs);
+  op->Run(scope, ctx.GetPlace());
+  TensorCopySync(*input_grad, platform::CPUPlace(), cpu_input_grad);
+  TensorCopySync(*filter_grad, platform::CPUPlace(), cpu_filter_grad);
+}
+template <typename T>
+void ComputeSumAndSquareSum(const framework::Tensor &cpu_out,
+                            framework::Tensor *cpu_sum,
+                            framework::Tensor *cpu_sum_of_square) {
+  auto dims = cpu_out.dims();
+  int64_t c = dims[3];
+  const T *cpu_out_ptr = cpu_out.data<T>();
+  float *cpu_sum_ptr =
+      cpu_sum->mutable_data<float>({1, 1, 1, c}, platform::CPUPlace());
+  float *cpu_sum_square_ptr = cpu_sum_of_square->mutable_data<float>(
+      {1, 1, 1, c}, platform::CPUPlace());
+  for (int j = 0; j < c; ++j) {
+    float tmp_sum = 0.0f;
+    float tmp_sum_of_squares = 0.0f;
+    for (int i = 0; i < cpu_out.numel() / c; ++i) {
+      float tmp_out = static_cast<float>(cpu_out_ptr[i * c + j]);
+      tmp_sum += tmp_out;
+      tmp_sum_of_squares += tmp_out * tmp_out;
+    }
+    cpu_sum_ptr[j] = tmp_sum;
+    cpu_sum_square_ptr[j] = tmp_sum_of_squares;
+  }
+}
+template <typename T>
+class CudnnNormConvolutionTester {
+ public:
+  CudnnNormConvolutionTester(int batch_size, int height, int width,
                              int input_channels, int output_channels,
                              int kernel_size, int stride) {
     batch_size_ = batch_size;
@@ -88,133 +215,183 @@ class TestCudnnNormConvOpForward {
     output_channels_ = output_channels;
     kernel_size_ = kernel_size;
     stride_ = stride;
-    pad_ = (kernel_size_ - 1) / 2;
+    padding_ = (kernel_size_ - 1) / 2;
+    SetUp();
   }
-  ~TestCudnnNormConvOpForward() {}
-  void SetUp() {
-    input_size_ = batch_size_ * height_ * width_ * input_channels_;
-    filter_size_ =
-        output_channels_ * input_channels_ * kernel_size_ * kernel_size_;
-    output_size_ = batch_size_ * height_ * width_ * output_channels_;
-    param_size_ = output_channels_;
-    input_vec_.resize(input_size_);
-    filter_raw_vec_.resize(filter_size_);
-    filter_pro_vec_.resize(filter_size_);
-    std::default_random_engine random(0);
-    std::uniform_real_distribution<float> dis(0.0, 1.0);
-    for (int i = 0; i < input_size_; ++i) {
-      input_vec_[i] = static_cast<T>(dis(random));
-    }
-    for (int i = 0; i < filter_size_; ++i) {
-      filter_raw_vec_[i] = static_cast<T>(dis(random));
-    }
-    // transpoes for filter
-    // NCHW->NHWC
-    for (int oc = 0; oc < output_channels_; ++oc) {
-      for (int kh = 0; kh < kernel_size_; ++kh) {
-        for (int kw = 0; kw < kernel_size_; ++kw) {
-          for (int ic = 0; ic < input_channels_; ++ic) {
-            int dst_idx = oc * kernel_size_ * kernel_size_ * input_channels_ +
-                          kh * kernel_size_ * input_channels_ +
-                          kw * input_channels_ + ic;
-            int src_idx = oc * kernel_size_ * kernel_size_ * input_channels_ +
-                          ic * kernel_size_ * kernel_size_ + kh * kernel_size_ +
-                          kw;
-            filter_pro_vec_[dst_idx] = filter_raw_vec_[src_idx];
-          }
-        }
-      }
-    }
-    framework::TensorFromVector<T>(input_vec_, *ctx_, &input_);
-    input_.Resize({batch_size_, height_, width_, input_channels_});
-    framework::TensorFromVector<T>(filter_raw_vec_, *ctx_, &filter_raw_);
-    filter_raw_.Resize(
-        {output_channels_, input_channels_, kernel_size_, kernel_size_});
-    framework::TensorFromVector<T>(filter_pro_vec_, *ctx_, &filter_pro_);
-    filter_pro_.Resize(
-        {output_channels_, kernel_size_, kernel_size_, input_channels_});
-    output_.Resize({batch_size_, height_, width_, output_channels_});
-    base_output_.Resize({batch_size_, height_, width_, output_channels_});
-    sum_.Resize({1, 1, 1, output_channels_});
-    sum_of_squares_.Resize({1, 1, 1, output_channels_});
-    ctx_->Wait();
-  }
-  void BaselineForward() {
-    Conv2DForwardCompute<T>(input_, filter_raw_, &base_output_, *ctx_);
-    ctx_->Wait();
-  }
-  // get forward results of cudnn_norm_conv
-  void FusedForward() {
-    auto input_shape = framework::vectorize<int>(input_.dims());
-    auto filter_shape = framework::vectorize<int>(filter_pro_.dims());
-    auto output_shape = framework::vectorize<int>(output_.dims());
-    T *input_ptr = input_.data<T>();
-    T *filter_ptr = filter_pro_.data<T>();
-    T *output_ptr = output_.mutable_data<T>(place_);
-    float *sum_ptr = sum_.mutable_data<float>(place_);
-    float *sum_of_squares_ptr = sum_of_squares_.mutable_data<float>(place_);
-    std::shared_ptr<op::CudnnNormConvolutionOp<T>> conv_op(
-        new op::CudnnNormConvolutionOp<T>());
-    conv_op->Init(*ctx_, input_shape, filter_shape, output_shape, pad_, stride_,
-                  dilate_, group_);
-    conv_op->Forward(*ctx_, input_ptr, filter_ptr, output_ptr, sum_ptr,
-                     sum_of_squares_ptr);
-    ctx_->Wait();
-  }
-  void Run() {
-    SetUp();
-    BaselineForward();
-    FusedForward();
-  }
-  // check forward correctness between baseline and results of normconv.
-  void CheckOut(const T diff, bool is_relative_atol = false) {
-    std::vector<T> base_output_vec, output_vec;
-    output_vec.resize(output_size_);
-    base_output_vec.resize(output_size_);
-    TensorToVector(base_output_, *ctx_, &base_output_vec);
-    TensorToVector(output_, *ctx_, &output_vec);
-    ctx_->Wait();
-    for (int i = 0; i < output_size_; ++i) {
-      if (is_relative_atol) {
-        EXPECT_LT(
-            std::abs((output_vec[i] - base_output_vec[i]) / base_output_vec[i]),
-            diff);
-      } else {
-        EXPECT_LT(std::abs(output_vec[i] - base_output_vec[i]), diff);
-      }
-    }
-  }
- private:
-  int batch_size_, height_, width_, input_channels_, output_channels_;
-  int kernel_size_, stride_, pad_;
-  const int dilate_ = 1;
-  const int group_ = 1;
-  int input_size_, filter_size_, output_size_, param_size_;
-  framework::Tensor input_, filter_raw_, filter_pro_, output_, base_output_;
-  framework::Tensor sum_, sum_of_squares_;
-  std::vector<T> input_vec_, filter_raw_vec_, filter_pro_vec_;
-  platform::CUDAPlace place_ = platform::CUDAPlace(0);
-  platform::CUDADeviceContext *ctx_ =
-      static_cast<platform::CUDADeviceContext *>(
-          platform::DeviceContextPool::Instance().Get(place_));
-};
+  ~CudnnNormConvolutionTester() {}
+  void CheckForward(float diff, bool is_relative_atol = false) {
+    platform::CUDADeviceContext *ctx =
+        static_cast<platform::CUDADeviceContext *>(
+            platform::DeviceContextPool::Instance().Get(
+                platform::CUDAPlace(0)));
+    framework::Tensor cpu_output_base;
+    framework::Tensor cpu_sum_base;
+    framework::Tensor cpu_sum_of_square_base;
+    BaselineForward(*ctx, &cpu_output_base, &cpu_sum_base,
+                    &cpu_sum_of_square_base);
+    framework::Tensor cpu_output;
+    framework::Tensor cpu_sum;
+    framework::Tensor cpu_sum_of_square;
+    FusedForward(*ctx, &cpu_output, &cpu_sum, &cpu_sum_of_square);
+    // Check forward correctness between baseline and results of normconv.
+    CheckOutput<T>(cpu_output, cpu_output_base, diff, is_relative_atol);
+    CheckOutput<float>(cpu_sum, cpu_sum_base, diff, is_relative_atol);
+    CheckOutput<float>(cpu_sum_of_square, cpu_sum_of_square_base, diff,
+                       is_relative_atol);
+  }
+  void CheckBackward(float diff, bool is_relative_atol = false) {
+    platform::CUDADeviceContext *ctx =
+        static_cast<platform::CUDADeviceContext *>(
+            platform::DeviceContextPool::Instance().Get(
+                platform::CUDAPlace(0)));
+    framework::Tensor cpu_input_grad_base;
+    framework::Tensor cpu_filter_nchw_grad_base;
+    framework::Tensor cpu_filter_nhwc_grad_base;
+    BaselineBackward(*ctx, &cpu_input_grad_base, &cpu_filter_nchw_grad_base);
+    TransposeNchwToNhwc<T>(cpu_filter_nchw_grad_base,
+                           &cpu_filter_nhwc_grad_base);
+    framework::Tensor cpu_input_grad;
+    framework::Tensor cpu_filter_nhwc_grad;
+    FusedBackward(*ctx, &cpu_input_grad, &cpu_filter_nhwc_grad);
+    // Check backward correctness between baseline and results of normconv.
+    CheckOutput<T>(cpu_input_grad, cpu_input_grad_base, diff, is_relative_atol);
+    CheckOutput<T>(cpu_filter_nhwc_grad, cpu_filter_nhwc_grad_base, diff,
+                   is_relative_atol);
+  }
+ private:
+  void SetUp() {
+    InitRandomTensor<T>({batch_size_, height_, width_, input_channels_},
+                        &cpu_input_);
+    InitRandomTensor<T>(
+        {output_channels_, input_channels_, kernel_size_, kernel_size_},
+        &cpu_filter_nchw_);
+    // transpoes for filter, NCHW -> NHWC
+    TransposeNchwToNhwc<T>(cpu_filter_nchw_, &cpu_filter_nhwc_);
+    InitRandomTensor<T>({batch_size_, height_, width_, output_channels_},
+                        &cpu_output_grad_);
+  }
+  void BaselineForward(const platform::CUDADeviceContext &ctx,
+                       framework::Tensor *cpu_output_base,
+                       framework::Tensor *cpu_sum_base,
+                       framework::Tensor *cpu_sum_of_square_base) {
+    ComputeConv2DForward<T>(ctx, cpu_input_, cpu_filter_nchw_,
+                            cpu_output_base);
+    ComputeSumAndSquareSum<T>(*cpu_output_base, cpu_sum_base,
+                              cpu_sum_of_square_base);
+  }
+  void BaselineBackward(const platform::CUDADeviceContext &ctx,
+                        framework::Tensor *cpu_input_grad_base,
+                        framework::Tensor *cpu_filter_grad_base) {
+    ComputeConv2DBackward<T>(ctx, cpu_input_, cpu_filter_nchw_,
+                             cpu_output_grad_, cpu_input_grad_base,
+                             cpu_filter_grad_base, stride_, padding_,
+                             dilation_);
+  }
+  // get forward results of cudnn_norm_conv
+  void FusedForward(const platform::CUDADeviceContext &ctx,
+                    framework::Tensor *cpu_output, framework::Tensor *cpu_sum,
+                    framework::Tensor *cpu_sum_of_square) {
+    framework::Tensor input;
+    framework::Tensor filter_nhwc;
+    framework::Tensor output;
+    framework::Tensor sum;
+    framework::Tensor sum_of_square;
+    auto place = ctx.GetPlace();
+    TensorCopySync(cpu_input_, place, &input);
+    TensorCopySync(cpu_filter_nhwc_, place, &filter_nhwc);
+    T *input_ptr = input.data<T>();
+    T *filter_ptr = filter_nhwc.data<T>();
+    T *output_ptr = output.mutable_data<T>(
+        {batch_size_, height_, width_, output_channels_}, place);
+    float *sum_ptr =
+        sum.mutable_data<float>({1, 1, 1, output_channels_}, place);
+    float *sum_of_square_ptr =
+        sum_of_square.mutable_data<float>({1, 1, 1, output_channels_}, place);
+    auto input_shape = framework::vectorize<int>(input.dims());
+    auto filter_shape = framework::vectorize<int>(filter_nhwc.dims());
+    auto output_shape = framework::vectorize<int>(output.dims());
+    op::CudnnNormConvolution<T> conv_op(ctx, input_shape, filter_shape,
+                                        output_shape, padding_, stride_,
+                                        dilation_, group_);
+    conv_op.Forward(ctx, input_ptr, filter_ptr, output_ptr, sum_ptr,
+                    sum_of_square_ptr);
+    TensorCopySync(output, platform::CPUPlace(), cpu_output);
+    TensorCopySync(sum, platform::CPUPlace(), cpu_sum);
+    TensorCopySync(sum_of_square, platform::CPUPlace(), cpu_sum_of_square);
+  }
+  void FusedBackward(const platform::CUDADeviceContext &ctx,
+                     framework::Tensor *cpu_input_grad,
+                     framework::Tensor *cpu_filter_grad) {
+    framework::Tensor input;
+    framework::Tensor filter_nhwc;
+    framework::Tensor output_grad;
+    framework::Tensor input_grad;
+    framework::Tensor filter_grad;
+    auto place = ctx.GetPlace();
+    TensorCopySync(cpu_input_, place, &input);
+    TensorCopySync(cpu_filter_nhwc_, place, &filter_nhwc);
+    TensorCopySync(cpu_output_grad_, place, &output_grad);
+    T *input_ptr = input.data<T>();
+    T *filter_ptr = filter_nhwc.data<T>();
+    T *output_grad_ptr = output_grad.data<T>();
+    T *input_grad_ptr = input_grad.mutable_data<T>(input.dims(), place);
+    T *filter_grad_ptr =
+        filter_grad.mutable_data<T>(filter_nhwc.dims(), place);
+    auto input_shape = framework::vectorize<int>(input.dims());
+    auto filter_shape = framework::vectorize<int>(filter_nhwc.dims());
+    auto output_shape = framework::vectorize<int>(output_grad.dims());
+    op::CudnnNormConvolutionGrad<T> conv_grad_op(ctx, input_shape, filter_shape,
+                                                 output_shape, padding_,
+                                                 stride_, dilation_, group_);
+    conv_grad_op.Backward(ctx, input_ptr, output_grad_ptr, filter_ptr,
+                          input_grad_ptr, filter_grad_ptr);
+    TensorCopySync(input_grad, platform::CPUPlace(), cpu_input_grad);
+    TensorCopySync(filter_grad, platform::CPUPlace(), cpu_filter_grad);
+  }
+ private:
+  int batch_size_;
+  int height_;
+  int width_;
+  int input_channels_;
+  int output_channels_;
+  int kernel_size_;
+  int stride_;
+  int padding_;
+  const int dilation_ = 1;
+  const int group_ = 1;
+  // Forward input
+  framework::Tensor cpu_input_;
+  framework::Tensor cpu_filter_nchw_;
+  framework::Tensor cpu_filter_nhwc_;
+  // Backward input
+  framework::Tensor cpu_output_grad_;
+};
 // test for fp16, kernel = 1, output_channels = input_channels
-TEST(CudnnNormConvForward, GPUCudnnNormConvForward1Fp16) {
+TEST(CudnnNormConvFp16, K1S1) {
   int batch_size = 4;
   int height = 56;
   int width = 56;
@@ -222,15 +399,15 @@ TEST(CudnnNormConvForward, GPUCudnnNormConvForward1Fp16) {
   int output_channels = 32;
   int kernel_size = 1;
   int stride = 1;
-  TestCudnnNormConvOpForward<paddle::platform::float16> test(
+  CudnnNormConvolutionTester<paddle::platform::float16> test(
       batch_size, height, width, input_channels, output_channels, kernel_size,
      stride);
-  test.Run();
-  test.CheckOut(static_cast<paddle::platform::float16>(1e-3), true);
+  test.CheckForward(1e-3, true);
+  test.CheckBackward(1e-3, true);
 }
 // test for fp16, kernel = 3, output_channels = input_channels
-TEST(CudnnNormConvForward, GPUCudnnNormConvForward2Fp16) {
+TEST(CudnnNormConvFp16, K3S1) {
   int batch_size = 4;
   int height = 56;
   int width = 56;
@@ -238,15 +415,15 @@ TEST(CudnnNormConvForward, GPUCudnnNormConvForward2Fp16) {
   int output_channels = 32;
   int kernel_size = 3;
   int stride = 1;
-  TestCudnnNormConvOpForward<paddle::platform::float16> test(
+  CudnnNormConvolutionTester<paddle::platform::float16> test(
      batch_size, height, width, input_channels, output_channels, kernel_size,
      stride);
-  test.Run();
-  test.CheckOut(static_cast<paddle::platform::float16>(1e-3), true);
+  test.CheckForward(1e-3, true);
+  test.CheckBackward(1e-3, true);
 }
 // test for fp16, kernel = 1, output_channels = input_channels * 4
-TEST(CudnnNormConvForward, GPUCudnnNormConvForward3Fp16) {
+TEST(CudnnNormConvFp16, K1S1O4) {
   int batch_size = 4;
   int height = 56;
   int width = 56;
@@ -254,9 +431,9 @@ TEST(CudnnNormConvForward, GPUCudnnNormConvForward3Fp16) {
   int output_channels = 128;
   int kernel_size = 1;
   int stride = 1;
-  TestCudnnNormConvOpForward<paddle::platform::float16> test(
+  CudnnNormConvolutionTester<paddle::platform::float16> test(
      batch_size, height, width, input_channels, output_channels, kernel_size,
      stride);
-  test.Run();
-  test.CheckOut(static_cast<paddle::platform::float16>(1e-3), true);
+  test.CheckForward(1e-3, true);
+  test.CheckBackward(1e-3, true);
 }