Unverified commit e32c9888, authored by Q qingqing01 and committed by GitHub

Double backward of conv2d. (#17211)

* Add conv2d_grad_grad_op
* Extract the cuDNN conv algorithm-search code into conv_cudnn_helper.h.
    - It is now used in conv2d_grad_grad.
    - The search code in conv2d and conv2d_grad will be simplified in a follow-up PR.
* Enhance and fix a bug in the gradient_checker unit-test utility.
* Support fetching empty variables; they are returned as None in Python.
Parent f456c8be
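For orientation, the sketch below shows how the new double-backward path is exercised from Python; it is modeled on the TestConvDoubleGradCheck unit test added at the bottom of this diff. The shape, filter size, and eps values are illustrative, the gradient_checker import assumes the script can see python/paddle/fluid/tests/unittests/gradient_checker.py, and a CUDA build is required because conv2d_grad_grad is only registered for cuDNN.

    import numpy as np
    import paddle.fluid as fluid
    import paddle.fluid.core as core
    import paddle.fluid.layers as layers
    # gradient_checker.py lives in python/paddle/fluid/tests/unittests.
    import gradient_checker

    def check_conv2d_double_grad(place):
        shape = [2, 4, 14, 16]  # NCHW input; values are illustrative
        dtype = np.float64
        x = layers.data('x', shape, False, dtype)
        y = layers.conv2d(x, num_filters=4, filter_size=1, bias_attr=False)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
        # Filter parameters created by conv2d; give them random init values.
        w = fluid.default_main_program().global_block().all_parameters()
        w_arr = [np.random.uniform(-1, 1, p.shape).astype(dtype) for p in w]
        # Compares analytical second-order gradients (conv2d_grad_grad on
        # cuDNN) against numerically computed ones; fetched gradients that
        # were never created come back as None and are treated as zeros.
        gradient_checker.double_grad_check(
            [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=0.005)

    if core.is_compiled_with_cuda():
        check_conv2d_double_grad(fluid.CUDAPlace(0))

Note that the `x += y_grads` change in gradient_checker.py further down means the injected output gradients are also checked as inputs of the second-order backward pass.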
......@@ -386,9 +386,10 @@ class ExecutionContext {
template <typename T>
T& GetKernelConfig(int idx) const {
PADDLE_ENFORCE(kernel_configs_ && kernel_configs_->size() > idx,
"%s selected kernel doesn't have kernel config %lu <= %d",
op_.Type().c_str(), kernel_configs_->size(), idx);
PADDLE_ENFORCE(
kernel_configs_ && kernel_configs_->size() > static_cast<size_t>(idx),
"%s selected kernel doesn't have kernel config %lu <= %d",
op_.Type().c_str(), kernel_configs_->size(), idx);
return *boost::get<std::shared_ptr<T>>(kernel_configs_->at(idx));
}
......
......@@ -644,6 +644,7 @@ class LeakyReluDoubleGrad : public framework::OperatorWithKernel {
//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
// dy = 0
//
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
public:
......@@ -655,11 +656,12 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
op->SetType("relu_grad_grad");
// input1: Out
op->SetInput("Out", Input("Out"));
// X@GRAD@GRAD: ddx
// input2: ddx
op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(Attrs());
// Out@GRAD@GRAD: ddy
// output1: dy
op->SetOutput("DOut", InputGrad("Out"));
// output2: ddy
op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<::paddle::framework::OpDesc>(op);
}
......
......@@ -54,7 +54,13 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
TensorCopySync(src_item, platform::CPUPlace(), &dst_item);
if (src_item.IsInitialized() && src_item.numel() > 0) {
TensorCopySync(src_item, platform::CPUPlace(), &dst_item);
} else {
// Do not copy if the src tensor is empty.
dst_item.clear();
dst_item.Resize({0});
}
dst_item.set_lod(src_item.lod());
VLOG(3) << "Fetch variable " << fetch_var_name << " to " << out_name;
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_desc.h"
namespace paddle {
namespace operators {
using framework::AlgorithmsCache;
struct ConvArgs {
cudnnHandle_t handle;
platform::TensorDescriptor idesc, odesc;
platform::FilterDescriptor wdesc;
platform::ConvolutionDescriptor cdesc;
const framework::Tensor *x, *w, *o;
// strides
std::vector<int> s;
// paddings
std::vector<int> p;
// dilations
std::vector<int> d;
ConvArgs(const framework::Tensor* x, const framework::Tensor* w,
const framework::Tensor* o, const std::vector<int> s,
const std::vector<int> p, const std::vector<int> d)
: x(x), w(w), o(o), s(s), p(p), d(d) {}
};
template <typename perf_t>
struct SearchAlgorithm {};
template <>
struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
using perf_t = cudnnConvolutionFwdAlgoPerf_t;
using algo_t = cudnnConvolutionFwdAlgo_t;
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic, int algo_cache_id,
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
algo_t algo;
if (!exhaustive) {
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
args.odesc.desc(), CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
VLOG(3) << "choose algo " << algo;
} else {
AlgorithmsCache<algo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());
algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
CUDNN_ENFORCE(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
args.handle, args.idesc.desc(), args.x->data<T>(),
args.wdesc.desc(), args.w->data<T>(), args.cdesc.desc(),
args.odesc.desc(), const_cast<T*>(args.o->data<T>()),
kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
perf_stat.data(), cudnn_workspace_ptr,
workspace_size_limit));
};
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
VLOG(3) << "FwdAlgo Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
<< " " << stat.memory;
}
return perf_stat[0].algo;
});
}
VLOG(3) << "choose algo " << algo;
return algo;
}
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
size_t workspace_size = 0;
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
args.odesc.desc(), algo, &workspace_size));
return workspace_size;
}
};
template <>
struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
using algo_t = cudnnConvolutionBwdDataAlgo_t;
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic, int algo_cache_id,
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
algo_t algo;
if (!exhaustive && !deterministic) {
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
args.handle, args.wdesc.desc(), args.idesc.desc(), args.cdesc.desc(),
args.odesc.desc(), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
} else if (deterministic) {
return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else {
AlgorithmsCache<algo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());
algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
CUDNN_ENFORCE(
platform::dynload::
cudnnFindConvolutionBackwardDataAlgorithmEx(
args.handle, args.wdesc.desc(), args.w->data<T>(),
args.odesc.desc(), args.o->data<T>(),
args.cdesc.desc(), args.idesc.desc(),
const_cast<T*>(args.x->data<T>()),
kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count,
perf_stat.data(), cudnn_workspace_ptr,
workspace_size_limit));
};
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
VLOG(3) << "BwdDataAlgo Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
<< " " << stat.memory;
}
return perf_stat[0].algo;
});
}
VLOG(3) << "choose algo " << algo;
return algo;
}
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
size_t workspace_size = 0;
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
args.handle, args.wdesc.desc(), args.idesc.desc(),
args.cdesc.desc(), args.odesc.desc(), algo, &workspace_size));
return workspace_size;
}
};
template <>
struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
using algo_t = cudnnConvolutionBwdFilterAlgo_t;
template <typename T>
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
bool deterministic, int algo_cache_id,
const framework::ExecutionContext& ctx) {
auto dtype = platform::CudnnDataType<T>::type;
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
algo_t algo;
if (!exhaustive && !deterministic) {
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
args.handle, args.idesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.wdesc.desc(),
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
} else if (deterministic) {
return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else {
AlgorithmsCache<algo_t>& algo_cache =
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
auto& dev_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());
algo = algo_cache.GetAlgorithm(
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
int returned_algo_count;
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
CUDNN_ENFORCE(
platform::dynload::
cudnnFindConvolutionBackwardFilterAlgorithmEx(
args.handle, args.idesc.desc(), args.x->data<T>(),
args.odesc.desc(), args.o->data<T>(),
args.cdesc.desc(), args.wdesc.desc(),
const_cast<T*>(args.w->data<T>()),
kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count,
perf_stat.data(), cudnn_workspace_ptr,
workspace_size_limit));
};
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
VLOG(3) << "BwdFilterAlgo Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
<< " " << stat.memory;
}
return perf_stat[0].algo;
});
}
VLOG(3) << "choose algo " << algo;
return algo;
}
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
size_t workspace_size = 0;
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
args.handle, args.idesc.desc(), args.odesc.desc(),
args.cdesc.desc(), args.wdesc.desc(), algo, &workspace_size));
return workspace_size;
}
};
} // namespace operators
} // namespace paddle
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/assert.h"
......@@ -46,6 +47,23 @@ template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
using framework::AlgorithmsCache;
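// Splits a 4-D (NCHW/NHWC) or 5-D (NCDHW/NDHWC) shape into its N, C, D, H
// and W extents; for 4-D tensors the depth D is reported as 1.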
static inline void GetNCDHW(const framework::DDim& dims,
const DataLayout& layout, int* N, int* C, int* D,
int* H, int* W) {
*N = dims[0];
*C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1];
int i = layout == DataLayout::kNCHW ? 0 : 1;
if (dims.size() == 5) {
*D = dims[2 - i];
*H = dims[3 - i];
*W = dims[4 - i];
} else {
*D = 1;
*H = dims[2 - i];
*W = dims[3 - i];
}
}
template <typename T>
class CUDNNConvOpKernel : public framework::OpKernel<T> {
public:
......@@ -99,33 +117,13 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize2int(filter->dims()), groups);
int input_channels = input->dims()[1];
int input_height, input_width, input_depth;
if (input->dims().size() == 5) {
input_depth = input->dims()[2];
input_height = input->dims()[3];
input_width = input->dims()[4];
} else { // dim size is enforced in InferShape
input_depth = 1;
input_height = input->dims()[2];
input_width = input->dims()[3];
}
int output_channels = filter->dims()[0];
int output_height, output_width, output_depth;
if (output->dims().size() == 5) {
output_depth = output->dims()[2];
output_height = output->dims()[3];
output_width = output->dims()[4];
} else {
output_depth = 1;
output_height = output->dims()[2];
output_width = output->dims()[3];
}
int i_n, i_c, i_d, i_h, i_w;
GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
int o_n, o_c, o_d, o_h, o_w;
GetNCDHW(output->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w);
int group_offset_in =
input_channels / groups * input_height * input_width * input_depth;
int group_offset_out =
output_channels / groups * output_height * output_width * output_depth;
int group_offset_in = i_c / groups * i_h * i_w * i_d;
int group_offset_out = o_c / groups * o_h * o_w * o_d;
int group_offset_filter = filter->numel() / groups;
// ------------------- cudnn conv workspace ---------------------
size_t workspace_size_in_bytes; // final workspace to allocate.
......@@ -164,6 +162,9 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims());
// TODO(dangqingqing) simplify the following code by SearchAlgorithm in
// conv_cudnn_helper.h
if ((!exhaustive_search) && (!half_float)) {
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
......@@ -315,34 +316,14 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
}
#endif
int input_channels = input->dims()[1];
int input_height, input_width, input_depth;
if (input->dims().size() == 5) {
input_depth = input->dims()[2];
input_height = input->dims()[3];
input_width = input->dims()[4];
} else { // dim size is enforced in InferShape
input_depth = 1;
input_height = input->dims()[2];
input_width = input->dims()[3];
}
int i_n, i_c, i_d, i_h, i_w;
GetNCDHW(input->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
int o_n, o_c, o_d, o_h, o_w;
GetNCDHW(output_grad->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h,
&o_w);
int output_grad_channels = filter->dims()[0];
int output_grad_height, output_grad_width, output_grad_depth;
if (input->dims().size() == 5) {
output_grad_depth = output_grad->dims()[2];
output_grad_height = output_grad->dims()[3];
output_grad_width = output_grad->dims()[4];
} else {
output_grad_depth = 1;
output_grad_height = output_grad->dims()[2];
output_grad_width = output_grad->dims()[3];
}
int group_offset_in =
input_channels / groups * input_height * input_width * input_depth;
int group_offset_out = output_grad_channels / groups * output_grad_height *
output_grad_width * output_grad_depth;
int group_offset_in = i_c / groups * i_h * i_w * i_d;
int group_offset_out = o_c / groups * o_h * o_w * o_d;
int group_offset_filter = filter->numel() / groups;
// ------------------- cudnn backward algorithm ---------------------
cudnnConvolutionBwdDataAlgo_t data_algo;
......@@ -367,6 +348,8 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
}
// TODO(dangqingqing) simplify the following code by SearchAlgorithm in
// conv_cudnn_helper.h
auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims());
auto handle = dev_ctx.cudnn_handle();
......@@ -512,6 +495,212 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
}
};
/*
* Inputs: I, W, dO, ddI, ddW
* Outputs: ddO, dW, dI
* ddo = conv(ddI, W) + conv(I, ddW)
* dW = conv_bp_filter(ddI, dO)
* dI = conv_bp_data(ddW, dO)
*/
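// In the kernel below, conv(.) maps to cudnnConvolutionForward,
// conv_bp_filter(.) to cudnnConvolutionBackwardFilter and
// conv_bp_data(.) to cudnnConvolutionBackwardData, each executed per
// group with the group_offset_{in,out,filter} strides computed below.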
template <typename T>
class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
auto X = ctx.Input<Tensor>("Input");
auto W = ctx.Input<Tensor>("Filter");
auto dO = ctx.Input<Tensor>("DOutput");
auto ddX = ctx.Input<Tensor>("DDInput");
auto ddW = ctx.Input<Tensor>("DDFilter");
auto ddO = ctx.Output<Tensor>("DDOutput");
auto dW = ctx.Output<Tensor>("DFilter");
auto dX = ctx.Output<Tensor>("DInput");
const T* x = X->data<T>();
const T* dy = dO->data<T>();
const T* w = W->data<T>();
const T* ddx = nullptr;
const T* ddw = nullptr;
T *dw, *dx, *ddy;
dw = dx = ddy = nullptr;
const std::vector<int>& strides = ctx.Attr<std::vector<int>>("strides");
const std::vector<int>& paddings = ctx.Attr<std::vector<int>>("paddings");
const std::vector<int>& dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
bool deterministic = FLAGS_cudnn_deterministic;
if (exhaustive_search && deterministic) {
PADDLE_THROW(
"Cann't set exhaustive_search True and "
"FLAGS_cudnn_deterministic True at same time.");
}
int iwo_group = groups;
int c_group = 1;
#if CUDNN_VERSION_MIN(7, 0, 1)
iwo_group = 1;
c_group = groups;
#endif
auto dtype = platform::CudnnDataType<T>::type;
auto handle = dev_ctx.cudnn_handle();
ConvArgs args1{ddX, W, ddO, strides, paddings, dilations};
ConvArgs args2{X, ddW, ddO, strides, paddings, dilations};
ConvArgs args3{ddX, dW, dO, strides, paddings, dilations};
ConvArgs args4{dX, ddW, dO, strides, paddings, dilations};
cudnnConvolutionFwdAlgo_t fwd_algo1 =
static_cast<cudnnConvolutionFwdAlgo_t>(0);
cudnnConvolutionFwdAlgo_t fwd_algo2 =
static_cast<cudnnConvolutionFwdAlgo_t>(0);
cudnnConvolutionBwdDataAlgo_t data_algo =
static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
cudnnConvolutionBwdFilterAlgo_t filter_algo =
static_cast<cudnnConvolutionBwdFilterAlgo_t>(0);
auto layout = GetCudnnTensorFormat(DataLayout::kNCHW);
// ddo = conv(ddI, W) + conv(I, ddW)
size_t workspace_size = 0;
if (ddO) {
ddy = ddO->mutable_data<T>(ctx.GetPlace());
args1.handle = handle;
args1.idesc.set(*ddX, iwo_group);
args1.wdesc.set(*W, layout, iwo_group);
args1.odesc.set(*ddO, iwo_group);
args1.cdesc.set(dtype, paddings, strides, dilations, c_group);
using search1 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, 0, ctx);
workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);
if (ddW) {
ddw = ddW->data<T>();
args2.handle = handle;
args2.idesc.set(*X, iwo_group);
args2.wdesc.set(*ddW, layout, iwo_group);
args2.odesc.set(*ddO, iwo_group);
args2.cdesc.set(dtype, paddings, strides, dilations, c_group);
using search2 = SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, 0, ctx);
workspace_size = std::max(workspace_size,
search2::GetWorkspaceSize(args2, fwd_algo2));
}
}
if (dW) {
dw = dW->mutable_data<T>(ctx.GetPlace());
args3.handle = handle;
args3.idesc.set(*ddX, iwo_group);
args3.wdesc.set(*dW, layout, iwo_group);
args3.odesc.set(*dO, iwo_group);
args3.cdesc.set(dtype, paddings, strides, dilations, c_group);
using search3 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo =
search3::Find<T>(args3, exhaustive_search, deterministic, 1, ctx);
workspace_size = std::max(workspace_size,
search3::GetWorkspaceSize(args3, filter_algo));
}
if (ddW && dX) {
dx = dX->mutable_data<T>(ctx.GetPlace());
args4.handle = handle;
args4.idesc.set(*dX, iwo_group);
args4.wdesc.set(*ddW, layout, iwo_group);
args4.odesc.set(*dO, iwo_group);
args4.cdesc.set(dtype, paddings, strides, dilations, c_group);
using search4 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo =
search4::Find<T>(args4, exhaustive_search, deterministic, 2, ctx);
workspace_size =
std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
}
int i_n, i_c, i_d, i_h, i_w;
GetNCDHW(X->dims(), DataLayout::kNCHW, &i_n, &i_c, &i_d, &i_h, &i_w);
int o_n, o_c, o_d, o_h, o_w;
GetNCDHW(dO->dims(), DataLayout::kNCHW, &o_n, &o_c, &o_d, &o_h, &o_w);
int group_offset_in = i_c / groups * i_h * i_w * i_d;
int group_offset_out = o_c / groups * o_h * o_w * o_d;
int group_offset_filter = W->numel() / groups;
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto wkspace_handle = dev_ctx.cudnn_workspace_handle();
if (ddO) {
ddx = ddX->data<T>();
for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc(
[&](void* workspace_ptr) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, args1.idesc.desc(), ddx + i * group_offset_in,
args1.wdesc.desc(), w + i * group_offset_filter,
args1.cdesc.desc(), fwd_algo1, workspace_ptr, workspace_size,
&beta, args1.odesc.desc(), ddy + i * group_offset_out));
},
workspace_size);
}
if (ddW) {
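// Accumulate the second term of ddO = conv(ddI, W) + conv(I, ddW): the call
// below passes alpha in cuDNN's output-scaling (beta) slot, so the result of
// conv(I, ddW) is added to the conv(ddI, W) values already stored in ddy.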
for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc(
[&](void* workspace_ptr) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, args2.idesc.desc(), x + i * group_offset_in,
args2.wdesc.desc(), ddw + i * group_offset_filter,
args2.cdesc.desc(), fwd_algo2, workspace_ptr,
workspace_size, &alpha, args2.odesc.desc(),
ddy + i * group_offset_out));
},
workspace_size);
}
}
}
if (dW) {
ddx = ddX->data<T>();
for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc(
[&](void* workspace_ptr) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, args3.idesc.desc(), ddx + i * group_offset_in,
args3.odesc.desc(), dy + i * group_offset_out,
args3.cdesc.desc(), filter_algo, workspace_ptr,
workspace_size, &beta, args3.wdesc.desc(),
dw + i * group_offset_filter));
},
workspace_size);
}
}
if (dX && ddW) {
ddw = ddW->data<T>();
for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc(
[&](void* workspace_ptr) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, args4.wdesc.desc(),
ddw + i * group_offset_filter, args4.odesc.desc(),
dy + i * group_offset_out, args4.cdesc.desc(), data_algo,
workspace_ptr, workspace_size, &beta, args4.idesc.desc(),
dx + i * group_offset_in));
},
workspace_size);
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -524,6 +713,11 @@ REGISTER_OP_KERNEL(conv2d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>,
paddle::operators::CUDNNConvGradOpKernel<plat::float16>);
REGISTER_OP_KERNEL(
conv2d_grad_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvDoubleGradOpKernel<float>,
paddle::operators::CUDNNConvDoubleGradOpKernel<double>,
paddle::operators::CUDNNConvDoubleGradOpKernel<plat::float16>);
REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>,
......
......@@ -506,13 +506,100 @@ class Conv3DGradMaker : public framework::SingleGradOpDescMaker {
}
};
/*
* Inputs: I, W, dO, ddI, ddW
* Outputs: ddO, dW, dI
*/
class Conv2DDoubleGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType(this->ForwardOpType() + "_grad");
// I, W, dO, ddI, ddW
op->SetInput("Input", Input("Input"));
op->SetInput("Filter", Input("Filter"));
op->SetInput("DOutput", Input(framework::GradVarName("Output")));
op->SetInput("DDInput", OutputGrad(framework::GradVarName("Input")));
op->SetInput("DDFilter", OutputGrad(framework::GradVarName("Filter")));
// ddO, dI, dW
// Unlike grad op, double grad op does not use name@GRAD@GRAD
// as key of ops' inputs and outputs.
op->SetOutput("DDOutput", InputGrad(framework::GradVarName("Output")));
op->SetOutput("DFilter", InputGrad("Filter"));
op->SetOutput("DInput", InputGrad("Input"));
op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(op);
}
};
void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
auto x_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("Filter");
auto do_dims = ctx->GetInputDim("DOutput");
if (ctx->HasOutput("DDOutput")) {
ctx->SetOutputDim("DDOutput", do_dims);
}
if (ctx->HasOutput("DFilter")) {
ctx->SetOutputDim("DFilter", w_dims);
}
if (ctx->HasOutput("DInput")) {
ctx->SetOutputDim("DInput", x_dims);
}
}
framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
} else {
PADDLE_THROW("Now ConvDoubleGrad only supports cuDNN.");
}
#endif
auto type = framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_,
customized_type_value);
#ifdef PADDLE_WITH_CUDA
if (library_ == framework::LibraryType::kCUDNN) {
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
if (configs.empty()) {
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>> p0(
new framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>());
configs.push_back(p0);
std::shared_ptr<
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>
p1(new framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>());
configs.push_back(p1);
std::shared_ptr<framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>
p2(new framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>());
configs.push_back(p2);
}
}
#endif
return type;
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
ops::ConvOpInferVarType, ops::Conv2DGradMaker);
REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad);
REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad, ops::Conv2DDoubleGradMaker);
REGISTER_OPERATOR(conv2d_grad_grad, ops::ConvOpDoubleGrad);
// depthwise convolution op
REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -107,6 +108,16 @@ class ConvOpGrad : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override;
};
class ConvOpDoubleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
class GemmConvKernel : public framework::OpKernel<T> {
public:
......
......@@ -29,13 +29,14 @@ namespace platform {
using framework::Tensor;
template <typename T>
cudnnDataType_t ToCudnnDataType(const T& t) {
inline cudnnDataType_t ToCudnnDataType(const T& t) {
auto type = framework::ToDataType(t);
return ToCudnnDataType(type);
}
template <>
cudnnDataType_t ToCudnnDataType(const framework::proto::VarType::Type& t) {
inline cudnnDataType_t ToCudnnDataType(
const framework::proto::VarType::Type& t) {
cudnnDataType_t type = CUDNN_DATA_FLOAT;
switch (t) {
case framework::proto::VarType::FP16:
......@@ -59,14 +60,14 @@ class ActivationDescriptor {
struct Deleter {
void operator()(T* t) {
if (t != nullptr) {
PADDLE_ENFORCE(dynload::cudnnDestroyActivationDescriptor(t));
CUDNN_ENFORCE(dynload::cudnnDestroyActivationDescriptor(t));
t = nullptr;
}
}
};
ActivationDescriptor() {
T* raw_ptr;
PADDLE_ENFORCE(dynload::cudnnCreateActivationDescriptor(&raw_ptr));
CUDNN_ENFORCE(dynload::cudnnCreateActivationDescriptor(&raw_ptr));
desc_.reset(raw_ptr);
}
template <typename T>
......@@ -88,14 +89,14 @@ class TensorDescriptor {
struct Deleter {
void operator()(T* t) {
if (t != nullptr) {
PADDLE_ENFORCE(dynload::cudnnDestroyTensorDescriptor(t));
CUDNN_ENFORCE(dynload::cudnnDestroyTensorDescriptor(t));
t = nullptr;
}
}
};
TensorDescriptor() {
T* raw_ptr;
PADDLE_ENFORCE(dynload::cudnnCreateTensorDescriptor(&raw_ptr));
CUDNN_ENFORCE(dynload::cudnnCreateTensorDescriptor(&raw_ptr));
desc_.reset(raw_ptr);
}
T* desc() { return desc_.get(); }
......@@ -111,7 +112,7 @@ class TensorDescriptor {
if (groups > 1) {
dims_with_group[1] = dims_with_group[1] / groups;
}
PADDLE_ENFORCE(dynload::cudnnSetTensorNdDescriptor(
CUDNN_ENFORCE(dynload::cudnnSetTensorNdDescriptor(
desc_.get(), ToCudnnDataType(tensor.type()), dims_with_group.size(),
dims_with_group.data(), strides.data()));
}
......@@ -120,5 +121,83 @@ class TensorDescriptor {
std::unique_ptr<T, Deleter> desc_;
};
class FilterDescriptor {
public:
using T = cudnnFilterStruct;
struct Deleter {
void operator()(T* t) {
if (t != nullptr) {
CUDNN_ENFORCE(dynload::cudnnDestroyFilterDescriptor(t));
t = nullptr;
}
}
};
FilterDescriptor() {
T* raw_ptr;
CUDNN_ENFORCE(dynload::cudnnCreateFilterDescriptor(&raw_ptr));
desc_.reset(raw_ptr);
}
T* desc() { return desc_.get(); }
T* desc() const { return desc_.get(); }
void set(const Tensor& tensor, const cudnnTensorFormat_t format,
const int groups = 1) {
auto dims = framework::vectorize2int(tensor.dims());
if (groups > 1) {
dims[1] = dims[1] / groups;
}
CUDNN_ENFORCE(dynload::cudnnSetFilterNdDescriptor(
desc_.get(), ToCudnnDataType(tensor.type()), format, dims.size(),
dims.data()));
}
private:
std::unique_ptr<T, Deleter> desc_;
};
class ConvolutionDescriptor {
public:
using T = cudnnConvolutionStruct;
struct Deleter {
void operator()(T* t) {
if (t != nullptr) {
CUDNN_ENFORCE(dynload::cudnnDestroyConvolutionDescriptor(t));
t = nullptr;
}
}
};
ConvolutionDescriptor() {
T* raw_ptr;
CUDNN_ENFORCE(dynload::cudnnCreateConvolutionDescriptor(&raw_ptr));
desc_.reset(raw_ptr);
}
T* desc() { return desc_.get(); }
T* desc() const { return desc_.get(); }
void set(cudnnDataType_t dtype, const std::vector<int>& pads,
const std::vector<int>& strides, const std::vector<int>& dilations,
const int groups = 1) {
cudnnDataType_t compute_type =
(dtype == CUDNN_DATA_DOUBLE) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT;
T* desc = desc_.get();
CUDNN_ENFORCE(dynload::cudnnSetConvolutionNdDescriptor(
desc, pads.size(), pads.data(), strides.data(), dilations.data(),
CUDNN_CROSS_CORRELATION, compute_type));
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
desc, CUDNN_DEFAULT_MATH));
#if CUDNN_VERSION_MIN(7, 0, 1)
CUDNN_ENFORCE(
platform::dynload::cudnnSetConvolutionGroupCount(desc, groups));
if (dtype == CUDNN_DATA_HALF) {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
desc, CUDNN_TENSOR_OP_MATH));
}
#endif
}
private:
std::unique_ptr<T, Deleter> desc_;
};
} // namespace platform
} // namespace paddle
......@@ -472,6 +472,9 @@ inline std::string TensorDTypeToPyDTypeStr(
} // namespace details
inline py::array TensorToPyArray(const framework::Tensor &tensor) {
if (!tensor.IsInitialized()) {
return py::array();
}
bool is_gpu_tensor = platform::is_gpu_place(tensor.place());
const auto &tensor_dims = tensor.dims();
auto tensor_dtype = tensor.type();
......
......@@ -119,7 +119,10 @@ def as_numpy(tensor):
They can not be completely cast to Python ndarray. \
Please set the parameter 'return_numpy' as 'False' to \
return LoDTensor itself directly.")
return np.array(tensor)
if tensor._is_initialized():
return np.array(tensor)
else:
return None
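# NOTE: a fetched variable whose tensor was never initialized (for example an
# empty gradient) is now returned as None instead of failing while being
# converted to a NumPy array.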
def has_feed_operators(block, feed_targets, feed_holder_name):
......
......@@ -82,6 +82,10 @@ def set_var_in_scope(scope, place, name, value, recursive_seq_len=None):
return t
def var_to_np_array_in_scope(scope, place, name):
return np.array(scope.var(name).get_tensor())
def make_jacobian(x, y_size, np_dtype):
if isinstance(x, fluid.framework.Variable):
return np.zeros((_product(x.shape), y_size), dtype=np_dtype)
......@@ -192,14 +196,18 @@ def _compute_analytical_jacobian(program, x, y, place, scope):
x = _as_list(x)
jacobian = make_jacobian(x, y_size, np_type)
dx = _as_list(dx)
for i in six.moves.xrange(y_size):
_set_item(dy_t, i, 1, np_type)
dx_res = exe.run(program, scope=scope, fetch_list=dx)
for j in six.moves.xrange(len(x)):
jacobian[j][:, i] = dx_res[j].flatten()
if dx_res[j] is not None:
jacobian[j][:, i] = dx_res[j].flatten()
else:
jacobian[j][:, i] = np.zeros(
dx[j].shape, dtype=np_type).flatten()
_set_item(dy_t, i, 0, np_type)
return jacobian
......@@ -242,6 +250,7 @@ def grad_check(x,
# check input arguments
x = _as_list(x)
y = _as_list(y)
for v in x:
v.stop_gradient = False
v.persistable = True
......@@ -274,9 +283,24 @@ def grad_check(x,
]
# [y_idx, x_idx]
analytical = [
_compute_analytical_jacobian(program, x, yi, place, scope) for yi in y
]
analytical = []
for yi in y:
prog = program.clone()
clone_x = []
clone_y = None
for b in prog.blocks:
if b.has_var(yi.name):
clone_y = b.var(yi.name)
break
for xi in x:
for b in prog.blocks:
if b.has_var(xi.name):
clone_x.append(b.var(xi.name))
break
analytical.append(
_compute_analytical_jacobian(prog, clone_x, clone_y, place, scope))
for i, (x_idx,
y_idx) in enumerate(product(*[range(len(x)), range(len(y))])):
......@@ -334,6 +358,7 @@ def double_grad_check(x,
if y_grads is None:
scope = fluid.executor.global_scope()
y_grads = []
y_grads_init = []
for yi in y:
dyi_name = _append_grad_suffix_(yi.name)
np_type = dtype_to_np_dtype(yi.dtype)
......@@ -343,9 +368,20 @@ def double_grad_check(x,
v = np.random.random(size=yi.shape).astype(np_type)
set_var_in_scope(scope, place, dyi_name, v)
y_grads.append(dy)
y_grads_init.append(v)
else:
y_grads = _as_list(y_grads)
y_grads_init = [
var_to_np_array_in_scope(scope, place, v.name) for v in y_grads
]
# append first order grads
target_grads = calc_gradient(y, x, y_grads)
# y_grads are the input of first-order backward,
# so, they are also the input of second-order backward.
x += y_grads
x_init = _as_list(x_init)
x_init += y_grads_init
grad_check(x, target_grads, x_init, place, program, eps, atol, rtol)
......@@ -46,7 +46,6 @@ class TestMulGradCheck(unittest.TestCase):
class TestReluDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
# the shape of input variable shoule be clearly specified, not inlcude -1.
shape = [2, 8]
eps = 0.005
dtype = np.float64
......@@ -71,7 +70,6 @@ class TestReluDoubleGradCheck(unittest.TestCase):
class TestLeakyReluDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
# the shape of input variable shoule be clearly specified, not inlcude -1.
shape = [3, 7]
eps = 0.005
alpha = 0.2
......@@ -79,6 +77,7 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase):
x = layers.data('x', shape, False, dtype)
x.persistable = True
y = layers.leaky_relu(x, alpha=alpha)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
x_arr[np.abs(x_arr) < 0.005] = 0.02
......@@ -90,8 +89,30 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConvDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 4, 14, 16]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv2d(x, 4, 1, bias_attr=False)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
if core.is_compiled_with_cuda():
places = [fluid.CUDAPlace(0)]
for p in places:
self.func(p)
if __name__ == "__main__":
......