未验证 提交 698698f2 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge branch 'develop' into fix_vlog

......@@ -45,7 +45,7 @@ IF(${CBLAS_PROVIDER} STREQUAL "MKLML")
ELSE()
MESSAGE(FATAL_ERROR "Should enable MKLML when build MKLDNN")
ENDIF()
SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result")
SET(MKLDNN_FLAG "-Wno-error=strict-overflow -Wno-error=unused-result -Wno-error=array-bounds")
SET(MKLDNN_FLAG "${MKLDNN_FLAG} -Wno-unused-result -Wno-unused-value")
SET(MKLDNN_CFLAG "${CMAKE_C_FLAGS} ${MKLDNN_FLAG}")
SET(MKLDNN_CXXFLAG "${CMAKE_CXX_FLAGS} ${MKLDNN_FLAG}")
......@@ -54,7 +54,7 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_DEPENDS}
GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
GIT_TAG "64e03a1939e0d526aa8e9f2e3f7dc0ad8d372944"
GIT_TAG "21fb5f2af1dd14e132af4f1b79160977ee487818"
PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
......@@ -118,9 +118,10 @@ paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon',
paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0))
paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None))
paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,))
paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR'))
paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR', None))
paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',))
paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.layers.resize_nearest ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'actual_shape'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))
......
......@@ -101,6 +101,7 @@ Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) {
std::vector<std::string> passes;
passes.push_back("graph_viz_pass"); // add graphviz for debug.
#ifdef PADDLE_WITH_MKLDNN
if (use_mkldnn_) {
VLOG(30) << "Adding MKL-DNN placement pass";
......@@ -110,13 +111,13 @@ void Analyzer::Run(Argument* argument) {
// infer_clean_graph_pass should be the first default pass
// after mkldnn_placement_pass.
passes.push_back("infer_clean_graph_pass");
passes.push_back("graph_viz_pass"); // add graphviz for debug.
for (auto& pass : ir_passes_) {
if (!disabled_ir_passes_.count(pass)) {
passes.push_back(pass);
passes.push_back("graph_viz_pass"); // add graphviz for debug.
}
}
passes.push_back("graph_viz_pass");
argument->Set(kFluidToIrPassesAttr, new std::vector<std::string>(passes));
for (auto& x : data_) {
......
......@@ -13,6 +13,8 @@
// limitations under the License.
#pragma once
#include <algorithm>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/naive_executor.h"
......
......@@ -59,7 +59,8 @@ void ReadBinaryFile(const std::string& filename, std::string* contents) {
bool IsPersistable(const framework::VarDesc* var) {
if (var->Persistable() &&
var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
var->GetType() != framework::proto::VarType::FETCH_LIST) {
var->GetType() != framework::proto::VarType::FETCH_LIST &&
var->GetType() != framework::proto::VarType::RAW) {
return true;
}
return false;
......
......@@ -134,7 +134,7 @@ class TensorRTEngine : public EngineBase {
std::unordered_map<std::string /*name*/, std::unique_ptr<framework::Tensor>>
weight_map;
// TODO: (NHZLX)
// TODO(NHZLX)
// In the normal case, the paddle-trt exists bug when runing the googlenet.
// When there are more than two convolutions of 1 * 1 with the same input, the
// paddle-tensorrt will do the merging optimization, which fuse those conv
......
......@@ -66,9 +66,10 @@ class AddPositionEncodingKernel : public framework::OpKernel<T> {
x_lod.empty() ? max_seq_len : x_lod[0][i + 1] - x_lod[0][i];
for (int j = 0; j < max_length; ++j) {
for (int k = 0; k < half_size; ++k) {
const double val = (half_size > 1)
? j / pow(10000.0, double(k) / (half_size - 1))
: j / 10000.0;
const double val =
(half_size > 1)
? j / pow(10000.0, static_cast<double>(k) / (half_size - 1))
: j / 10000.0;
dst_ptr[k] = src_ptr[k] * alpha + sin(val) * beta;
dst_ptr[half_size + k] =
src_ptr[half_size + k] * alpha + cos(val) * beta;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class BilinearInterpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input_t = ctx.Input<Tensor>("X"); // float tensor
auto* output_t = ctx.Output<Tensor>("Out"); // float tensor
auto out_dims = output_t->dims();
auto* input = input_t->data<T>();
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size_t = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) {
auto out_size_data = out_size_t->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
auto* output = output_t->mutable_data<T>(
{out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace());
int batch_size = input_t->dims()[0];
int channels = input_t->dims()[1];
int in_h = input_t->dims()[2];
int in_w = input_t->dims()[3];
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = channels * in_hw;
int out_chw = channels * out_hw;
float ratio_h =
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) {
memcpy(output, input, input_t->numel() * sizeof(T));
} else {
for (int k = 0; k < batch_size; ++k) { // loop for batches
for (int i = 0; i < out_h; ++i) { // loop for images
int h = ratio_h * i;
int hid = (h < in_h - 1) ? 1 : 0;
float h1lambda = ratio_h * i - h;
float h2lambda = 1.f - h1lambda;
for (int j = 0; j < out_w; ++j) {
int w = ratio_w * j;
int wid = (w < in_w - 1) ? 1 : 0;
float w1lambda = ratio_w * j - w;
float w2lambda = 1.f - w1lambda;
// calculate four position for bilinear interpolation
const T* in_pos = &input[k * in_chw + h * in_w + w];
T* out_pos = &output[k * out_chw + i * out_w + j];
for (int c = 0; c < channels; ++c) { // loop for channels
// bilinear interpolation
out_pos[0] = static_cast<T>(
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) +
h1lambda * (w2lambda * in_pos[hid * in_w] +
w1lambda * in_pos[hid * in_w + wid]));
in_pos += in_hw;
out_pos += out_hw;
}
}
}
}
}
}
};
template <typename T>
class BilinearInterpGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* d_output = d_output_t->data<T>();
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
auto& device_ctx =
ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, d_input_t, static_cast<T>(0.0));
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size_t = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) {
auto out_size_data = out_size_t->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
int batch_size = d_input_t->dims()[0];
int channels = d_input_t->dims()[1];
int in_h = d_input_t->dims()[2];
int in_w = d_input_t->dims()[3];
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = channels * in_hw;
int out_chw = channels * out_hw;
float ratio_h =
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) {
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
} else {
for (int k = 0; k < batch_size; ++k) { // loop for batches
for (int i = 0; i < out_h; ++i) { // loop for images
int h = ratio_h * i;
int hid = (h < in_h - 1) ? 1 : 0;
float h1lambda = ratio_h * i - h;
float h2lambda = 1 - h1lambda;
for (int j = 0; j < out_w; ++j) {
int w = ratio_w * j;
int wid = (w < in_w - 1) ? 1 : 0;
float w1lambda = ratio_w * j - w;
float w2lambda = 1 - w1lambda;
T* in_pos = &d_input[k * in_chw + h * in_w + w];
const T* out_pos = &d_output[k * out_chw + i * out_w + j];
for (int c = 0; c < channels; ++c) { // loop for channels
in_pos[0] += static_cast<T>(h2lambda * w2lambda * out_pos[0]);
in_pos[wid] += static_cast<T>(h2lambda * w1lambda * out_pos[0]);
in_pos[hid * in_w] +=
static_cast<T>(h1lambda * w2lambda * out_pos[0]);
in_pos[hid * in_w + wid] +=
static_cast<T>(h1lambda * w1lambda * out_pos[0]);
in_pos += in_hw;
out_pos += out_hw;
}
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -15,15 +15,22 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_bool(cudnn_deterministic, false,
"Whether allow using an autotuning algorithm for convolution "
"operator. The autotuning algorithm may be non-deterministic. If "
"true, the algorithm is deterministic.");
DEFINE_uint64(conv_workspace_size_limit, 4096,
"cuDNN convolution workspace limit in MB unit.");
DEFINE_bool(cudnn_exhaustive_search, false,
"Whether enable exhaustive search for cuDNN convolution or "
"not, defalut is False.");
namespace paddle {
namespace operators {
......@@ -36,13 +43,25 @@ using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
static constexpr char kCUDNNFwdAlgoCache[] = "kCUDNNFwdAlgoCache";
static constexpr char kCUDNNBwdDataAlgoCache[] = "kCUDNNBwdDataAlgoCache";
static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache";
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
static_cast<size_t>(1024) * 1024 * 1024;
static constexpr size_t kNUM_CUDNN_FWD_ALGS =
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS =
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS =
CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
template <typename T>
class CUDNNConvOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
auto* input = ctx.Input<Tensor>("Input");
......@@ -55,6 +74,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
int groups = ctx.Attr<int>("groups");
int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
......@@ -120,19 +141,19 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv workspace ---------------------
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024;
if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
int64_t max_user_size =
std::max(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
user_workspace_size);
workspace_size_limit = max_user_size * 1024 * 1024;
}
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t algo;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
bool half_float = false;
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
// Tensor core is supported since the volta GPU and
// is only enabled when input and filter data are float16
......@@ -143,6 +164,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
// Currently tensor core is only enabled using this algo
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
half_float = true;
VLOG(50) << "use cudnn_tensor_op_math";
} else {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
......@@ -151,6 +173,57 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
}
#endif
auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims());
if ((!exhaustive_search) && (!half_float)) {
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo));
VLOG(3) << "cuDNN forward algo " << algo;
} else if (exhaustive_search && (!half_float)) {
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* algo_cache = nullptr;
if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) {
algo_cache =
ctx.scope()
.FindVar(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
} else {
algo_cache =
const_cast<framework::Scope&>(ctx.scope())
.Var(kCUDNNFwdAlgoCache)
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
}
algo = algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
fwd_perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc,
output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace,
workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = fwd_perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
<< " " << stat.memory;
}
return fwd_perf_stat[0].algo;
});
VLOG(3) << "choose algo " << algo;
} else {
PADDLE_ENFORCE(half_float,
"cuDNN exhaustive search doesn't support half float.");
}
// get workspace size able to allocate
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
......@@ -162,7 +235,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv forward ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
for (int i = 0; i < groups; i++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
......@@ -180,6 +252,7 @@ template <typename T>
class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
auto input = ctx.Input<Tensor>("Input");
......@@ -198,6 +271,13 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
int groups = ctx.Attr<int>("groups");
int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
if (exhaustive_search && FLAGS_cudnn_deterministic) {
PADDLE_THROW(
"Cann't set exhaustive_search True and "
"FLAGS_cudnn_deterministic True at same time.");
}
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
......@@ -265,14 +345,66 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnnConvolutionBwdFilterAlgo_t filter_algo;
size_t workspace_size_in_bytes = 0, tmp_size = 0;
size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024;
if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
int64_t max_user_size =
std::max(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
user_workspace_size);
workspace_size_limit = max_user_size * 1024 * 1024;
}
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto x_dims = framework::vectorize(input->dims());
auto f_dims = framework::vectorize(filter->dims());
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
if (input_grad) {
if (!FLAGS_cudnn_deterministic) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
if (exhaustive_search) {
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* data_algo_cache;
if (ctx.scope().FindVar(kCUDNNBwdDataAlgoCache)) {
data_algo_cache =
ctx.scope()
.FindVar(kCUDNNBwdDataAlgoCache)
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>();
} else {
data_algo_cache =
const_cast<framework::Scope&>(ctx.scope())
.Var(kCUDNNBwdDataAlgoCache)
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>();
}
data_algo = data_algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionBwdDataAlgoPerf_t,
kNUM_CUDNN_BWD_DATA_ALGS>
data_perf_stat;
auto cudnn_find_bd_data_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(
platform::dynload::
cudnnFindConvolutionBackwardDataAlgorithmEx(
handle, cudnn_filter_desc, filter_data,
cudnn_output_grad_desc, output_grad_data,
cudnn_conv_desc, cudnn_input_desc, input_grad_data,
kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count,
data_perf_stat.data(), cudnn_workspace,
workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_bd_data_func,
workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = data_perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
<< " " << stat.memory;
}
return data_perf_stat[0].algo;
});
VLOG(3) << "cuDNN backward data algo " << data_algo;
} else if (FLAGS_cudnn_deterministic) {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
} else {
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc,
......@@ -285,10 +417,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
cudnn_input_desc,
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &data_algo));
} else {
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
}
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
handle, cudnn_filter_desc, cudnn_output_grad_desc,
......@@ -297,17 +426,54 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
}
if (filter_grad) {
if (!FLAGS_cudnn_deterministic) {
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
if (exhaustive_search) {
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>* f_algo_cache;
if (ctx.scope().FindVar(kCUDNNBwdFilterAlgoCache)) {
f_algo_cache =
ctx.scope()
.FindVar(kCUDNNBwdFilterAlgoCache)
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>();
} else {
f_algo_cache =
const_cast<framework::Scope&>(ctx.scope())
.Var(kCUDNNBwdFilterAlgoCache)
->GetMutable<
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>();
}
filter_algo = f_algo_cache->GetAlgorithm(
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
int returned_algo_count;
std::array<cudnnConvolutionBwdFilterAlgoPerf_t,
kNUM_CUDNN_BWD_FILTER_ALGS>
filter_perf_stat;
auto cudnn_find_bd_f_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(
platform::dynload::
cudnnFindConvolutionBackwardFilterAlgorithmEx(
handle, cudnn_input_desc, input_data,
cudnn_output_grad_desc, output_grad_data,
cudnn_conv_desc, cudnn_filter_desc,
filter_grad_data, kNUM_CUDNN_BWD_FILTER_ALGS,
&returned_algo_count, filter_perf_stat.data(),
cudnn_workspace, workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_bd_f_func,
workspace_size_limit);
return filter_perf_stat[0].algo;
});
VLOG(3) << "cuDNN backward filter algo " << filter_algo;
} else if (FLAGS_cudnn_deterministic) {
filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
} else {
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
handle, cudnn_input_desc, cudnn_output_grad_desc,
cudnn_conv_desc, cudnn_filter_desc,
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &filter_algo));
} else {
filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
}
CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
......@@ -317,7 +483,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv backward data ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace operators {
template <typename TAlgorithm>
class AlgorithmsCache {
public:
// Caches the best algorithm for a given
// combination of tensor dimensions & compute data type.
TAlgorithm GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations,
int algorithmFlags, // can set for different data type
std::function<TAlgorithm()> gen_func);
private:
std::unordered_map<int64_t, TAlgorithm> hash_;
std::mutex mutex_;
};
template <typename TAlgorithm>
TAlgorithm AlgorithmsCache<TAlgorithm>::GetAlgorithm(
const std::vector<int64_t>& dims1, const std::vector<int64_t>& dims2,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::vector<int>& dilations, int algorithmFlags,
std::function<TAlgorithm()> gen_func) {
std::lock_guard<std::mutex> lock(mutex_);
int64_t seed = 0;
// Hash all of the inputs, use to try and look up a previously
// discovered algorithm, or fall back to generating a new one.
std::hash<int64_t> hashFn;
// do hash like boost
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
for (const auto num : dims1) {
seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
for (const auto num : dims2) {
seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1;
}
for (const auto num : strides) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 2;
}
for (const auto num : paddings) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 3;
}
for (const auto num : dilations) {
seed ^= hashFn(static_cast<int64_t>(num)) + 0x9e3779b9 + (seed << 6) +
(seed >> 2) + 4;
}
seed ^= hashFn(static_cast<int64_t>(algorithmFlags)) + 0x9e3779b9 +
(seed << 6) + (seed >> 2) + 5;
if (seed == 0) return gen_func();
if (hash_.find(seed) == hash_.end()) {
TAlgorithm value = gen_func();
hash_[seed] = value;
}
return hash_[seed];
}
} // namespace operators
} // namespace paddle
......@@ -375,8 +375,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(),
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
// Currently used whenever bias is != nullptr.
auto dst_md = platform::MKLDNNMemDesc(
......
......@@ -189,6 +189,11 @@ void Conv2DOpMaker::Make() {
"workspace size can increase performance but also requires "
"better hardware. This size should be chosen carefully.")
.SetDefault(4096);
AddAttr<bool>("exhaustive_search",
"(bool, default false) cuDNN has many algorithm to calculation "
"convolution, whether enable exhaustive search ",
"for cuDNN convolution or not, defalut is False.")
.SetDefault(false);
AddComment(R"DOC(
Convolution Operator.
......@@ -283,7 +288,11 @@ void Conv3DOpMaker::Make() {
"workspace size can increase performance but also requires "
"better hardware. This size should be chosen carefully.")
.SetDefault(4096);
AddAttr<bool>("exhaustive_search",
"(bool, default false) cuDNN has many algorithm to calculation "
"convolution, whether enable exhaustive search ",
"for cuDNN convolution or not, defalut is False.")
.SetDefault(false);
AddComment(R"DOC(
Convolution3D Operator.
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
......@@ -9,7 +9,8 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bilinear_interp_op.h"
#include "paddle/fluid/operators/interpolate_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
......@@ -18,27 +19,34 @@ namespace operators {
using framework::Tensor;
class BilinearInterpOp : public framework::OperatorWithKernel {
class InterpolateOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of BilinearInterOp should not be null.");
"Input(X) of InterpolateOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of BilinearInterOp should not be null.");
"Output(Out) of InterpolationOp should not be null.");
auto interp_method = ctx->Attrs().Get<std::string>("interp_method");
PADDLE_ENFORCE(
"bilinear" == interp_method || "nearest" == interp_method,
"Interpolation method can only be \"bilinear\" or \"nearest\".");
auto dim_x = ctx->GetInputDim("X"); // NCHW format
int out_h = ctx->Attrs().Get<int>("out_h");
int out_w = ctx->Attrs().Get<int>("out_w");
PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
if (ctx->HasInput("OutSize")) {
if (ctx->HasInput("OutSize") && ctx->IsRuntime()) {
auto out_size_dim = ctx->GetInputDim("OutSize");
PADDLE_ENFORCE_EQ(out_size_dim.size(), 1,
"OutSize's dimension size must be 1");
PADDLE_ENFORCE_EQ(out_size_dim[0], 2, "OutSize's dim[0] must be 2");
ctx->ShareLoD("X", "Out");
return;
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
ctx->SetOutputDim("Out", framework::make_ddim(dim_out));
......@@ -52,35 +60,53 @@ class BilinearInterpOp : public framework::OperatorWithKernel {
}
};
class BilinearInterpOpMaker : public framework::OpProtoAndCheckerMaker {
class InterpolateOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input tensor of bilinear interpolation, "
"This is a 4-D tensor with shape of (N x C x h x w)");
"The input tensor of interpolate operator, "
"This is a 4-D tensor with shape of [N, C, H, w].");
AddInput("OutSize",
"This is a 1-D tensor with two number. "
"This is a 1-D tensor with two numbers to specify output size. "
"The first number is height and the second number is width.")
.AsDispensable();
AddOutput("Out", "The dimension of output is (N x C x out_h x out_w)");
AddOutput("Out",
"The output tensor of interpolate operator, "
"This is a 4-D tensor with shape of [N, C, H, W].");
AddAttr<int>("out_h", "output height of bilinear interpolation op.");
AddAttr<int>("out_w", "output width of bilinear interpolation op.");
AddAttr<int>("out_h", "output height of interpolate op.");
AddAttr<int>("out_w", "output width of interpolate op.");
AddAttr<std::string>(
"interp_method",
"(string), interpolation method, can be \"bilinear\" for "
"bilinear interpolation and \"nearest\" for nearest "
"neighbor interpolation.");
AddComment(R"DOC(
This operator samples input X to given output shape by using specified
interpolation method, the interpolation methods can be \"nearest\"
for nearest neighbor interpolation and \"bilinear\" for bilinear
interpolation.
Nearest neighbor interpolation is to perform nearest neighbor interpolation
in both the 3rd dimention(in height direction) and the 4th dimention(in width
direction) on input tensor.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid.
The key idea is to perform linear interpolation first in one
direction, and then again in the other direction.
For details, please refer to Wikipedia:
W-direction in this op) on a rectilinear 2D grid. The key idea is
to perform linear interpolation first in one direction, and then
again in the other direction.
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
)DOC");
}
};
class BilinearInterpOpGrad : public framework::OperatorWithKernel {
class InterpolateOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -106,11 +132,11 @@ class BilinearInterpOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(bilinear_interp, ops::BilinearInterpOp,
ops::BilinearInterpOpMaker,
REGISTER_OPERATOR(interpolate, ops::InterpolateOp, ops::InterpolateOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(bilinear_interp_grad, ops::BilinearInterpOpGrad);
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::BilinearInterpKernel<float>,
ops::BilinearInterpKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(bilinear_interp_grad,
ops::BilinearInterpGradKernel<float>);
REGISTER_OPERATOR(interpolate_grad, ops::InterpolateOpGrad);
REGISTER_OP_CPU_KERNEL(interpolate, ops::InterpolateKernel<float>,
ops::InterpolateKernel<double>,
ops::InterpolateKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(interpolate_grad, ops::InterpolateGradKernel<float>,
ops::InterpolateGradKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
......@@ -9,7 +9,8 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bilinear_interp_op.h"
#include <string>
#include "paddle/fluid/operators/interpolate_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
......@@ -17,15 +18,72 @@ namespace operators {
using framework::Tensor;
template <typename T>
__global__ void KeNearestNeighborInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
int in_img_idy = static_cast<int>(ratio_h * out_img_idy + 0.5);
int out_img_idx = tid % out_img_w;
int in_img_idx = static_cast<int>(ratio_w * out_img_idx + 0.5);
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
}
}
template <typename T>
__global__ void KeNearestNeighborInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_img_w;
int in_img_idy = static_cast<int>(ratio_h * out_img_idy + 0.5);
int out_img_idx = tid % out_img_w;
int in_img_idx = static_cast<int>(ratio_w * out_img_idx + 0.5);
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
const T out_pos = out[out_id_h * output_w + out_id_w];
platform::CudaAtomicAdd(in_pos, out_pos);
}
}
template <typename T>
__global__ void KeBilinearInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const T ratio_h, const T ratioW) {
const size_t num_channels, const float ratio_h, const float ratio_w) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) {
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
......@@ -39,9 +97,9 @@ __global__ void KeBilinearInterpFw(
T h2lambda = 1.f - h1lambda;
int out_img_idx = tid % out_img_w;
int in_img_idx = ratioW * out_img_idx;
int in_img_idx = ratio_w * out_img_idx;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T w1lambda = ratioW * out_img_idx - in_img_idx;
T w1lambda = ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
......@@ -60,10 +118,11 @@ __global__ void KeBilinearInterpBw(
T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h,
const size_t input_w, const T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const T ratio_h, const T ratioW) {
const size_t num_channels, const T ratio_h, const T ratio_w) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < nthreads) {
int stride = blockDim.x * gridDim.x;
for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
......@@ -77,122 +136,146 @@ __global__ void KeBilinearInterpBw(
T h2lambda = 1.f - h1lambda;
int out_img_idx = tid % out_img_w;
int in_img_idx = ratioW * out_img_idx;
int in_img_idx = ratio_w * out_img_idx;
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
T w1lambda = ratioW * out_img_idx - in_img_idx;
T w1lambda = ratio_w * out_img_idx - in_img_idx;
T w2lambda = 1.f - w1lambda;
T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
const T* out_pos = &out[out_id_h * output_w + out_id_w];
atomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
atomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
atomicAdd(&in_pos[h_id * in_img_w], h1lambda * w2lambda * out_pos[0]);
atomicAdd(&in_pos[h_id * in_img_w + w_id],
h1lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[h_id * in_img_w],
h1lambda * w2lambda * out_pos[0]);
platform::CudaAtomicAdd(&in_pos[h_id * in_img_w + w_id],
h1lambda * w1lambda * out_pos[0]);
}
}
template <typename T>
class BilinearInterpOpCUDAKernel : public framework::OpKernel<T> {
class InterpolateOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"This kernel only runs on GPU device.");
auto* input_t = ctx.Input<Tensor>("X"); // float tensor
auto* output_t = ctx.Output<Tensor>("Out"); // float tensor
auto* input = input_t->data<T>();
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* input_data = input->data<T>();
auto interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_dims = output_t->dims();
auto out_size_t = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) {
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes);
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
auto* output = output_t->mutable_data<T>(
{out_dims[0], out_dims[1], out_h, out_w}, ctx.GetPlace());
int batch_size = input_t->dims()[0];
int channels = input_t->dims()[1];
int in_h = input_t->dims()[2];
int in_w = input_t->dims()[3];
int n = input->dims()[0];
int c = input->dims()[1];
int in_h = input->dims()[2];
int in_w = input->dims()[3];
auto* output_data =
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = channels * in_hw;
int out_chw = channels * out_hw;
int in_chw = c * in_hw;
int out_chw = c * out_hw;
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
float ratio_h =
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) {
memcpy(output, input, input_t->numel() * sizeof(T));
} else {
int threadNum = batch_size * out_chw;
int blocks = (threadNum + 1024 - 1) / 1024;
framework::TensorCopy(*input, ctx.GetPlace(), output);
return;
}
int pixelNum = n * out_chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
if ("nearest" == interp_method) {
KeNearestNeighborInterpFw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w);
} else if ("bilinear" == interp_method) {
KeBilinearInterpFw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
input, in_h, in_w, batch_size, in_chw, output, out_h, out_w,
batch_size, out_chw, channels, ratio_h, ratio_w);
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w);
}
}
};
template <typename T>
class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_input_t = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* d_output_t = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* d_output = d_output_t->data<T>();
auto* d_input = d_input_t->mutable_data<T>(ctx.GetPlace());
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* output_grad_data = output_grad->data<T>();
auto* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
auto& device_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> zero;
zero(device_ctx, d_input_t, static_cast<T>(0.0));
zero(device_ctx, input_grad, static_cast<T>(0.0));
auto interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size_t = ctx.Input<Tensor>("OutSize");
if (out_size_t != nullptr) {
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Tensor sizes;
framework::TensorCopy(*out_size_t, platform::CPUPlace(), &sizes);
framework::TensorCopy(*out_size, platform::CPUPlace(), &sizes);
auto size_data = sizes.data<int>();
out_h = size_data[0];
out_w = size_data[1];
}
int batch_size = d_input_t->dims()[0];
int channels = d_input_t->dims()[1];
int in_h = d_input_t->dims()[2];
int in_w = d_input_t->dims()[3];
int n = input_grad->dims()[0];
int c = input_grad->dims()[1];
int in_h = input_grad->dims()[2];
int in_w = input_grad->dims()[3];
int in_hw = in_h * in_w;
int out_hw = out_h * out_w;
int in_chw = channels * in_hw;
int out_chw = channels * out_hw;
int in_chw = c * in_hw;
int out_chw = c * out_hw;
T ratio_h = (out_h > 1) ? static_cast<T>(in_h - 1) / (out_h - 1) : 0.f;
T ratio_w = (out_w > 1) ? static_cast<T>(in_w - 1) / (out_w - 1) : 0.f;
float ratio_h =
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) {
memcpy(d_input, d_output, d_input_t->numel() * sizeof(T));
} else {
int threadNum = batch_size * out_chw;
int blocks = (threadNum + 1024 - 1) / 1024;
framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
return;
}
int pixelNum = n * out_chw;
int grid_dim = (pixelNum + 512 - 1) / 512;
grid_dim = grid_dim > 8 ? 8 : grid_dim;
if ("nearest" == interp_method) {
KeNearestNeighborInterpBw<
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
out_w, n, out_chw, c, ratio_h, ratio_w);
} else if ("bilinear" == interp_method) {
KeBilinearInterpBw<
T><<<blocks, 1024, 0, ctx.cuda_device_context().stream()>>>(
d_input, in_h, in_w, batch_size, in_chw, d_output, out_h, out_w,
batch_size, out_chw, channels, ratio_h, ratio_w);
T><<<grid_dim, 512, 0, ctx.cuda_device_context().stream()>>>(
input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h,
out_w, n, out_chw, c, ratio_h, ratio_w);
}
}
};
......@@ -201,7 +284,9 @@ class BilinearInterpGradOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(bilinear_interp,
ops::BilinearInterpOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(bilinear_interp_grad,
ops::BilinearInterpGradOpCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(interpolate, ops::InterpolateOpCUDAKernel<float>,
ops::InterpolateOpCUDAKernel<double>,
ops::InterpolateOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(interpolate_grad,
ops::InterpolateGradOpCUDAKernel<float>,
ops::InterpolateGradOpCUDAKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Tensor = framework::Tensor;
template <typename T>
static void NearestNeighborInterpolate(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int n, const int c,
const int out_h, const int out_w) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = static_cast<int>(ratio_h * k + 0.5);
for (int l = 0; l < out_w; l++) {
int in_l = static_cast<int>(ratio_w * l + 0.5);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
output_t(i, j, k, l) = input_t(i, j, in_k, in_l);
}
}
}
}
}
template <typename T>
static void BilinearInterpolation(const Tensor& input, Tensor* output,
const float ratio_h, const float ratio_w,
const int in_h, const int in_w, const int n,
const int c, const int out_h,
const int out_w) {
auto input_t = EigenTensor<T, 4>::From(input);
auto output_t = EigenTensor<T, 4>::From(*output);
for (int k = 0; k < out_h; k++) { // loop for images
int y_n = static_cast<int>(ratio_h * k);
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float d_n = ratio_h * k - y_n;
float d_s = 1.f - d_n;
for (int l = 0; l < out_w; l++) {
int x_w = static_cast<int>(ratio_w * l);
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float d_w = ratio_w * l - x_w;
float d_e = 1.f - d_w;
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bilinear interpolation
output_t(i, j, k, l) = input_t(i, j, y_n, x_w) * d_s * d_e +
input_t(i, j, y_s, x_w) * d_n * d_e +
input_t(i, j, y_n, x_e) * d_s * d_w +
input_t(i, j, y_s, x_e) * d_n * d_w;
}
}
}
}
}
template <typename T>
static void NearestNeighborInterpolateGrad(const Tensor& output_grad,
Tensor* input_grad,
const float ratio_h,
const float ratio_w, const int n,
const int c, const int out_h,
const int out_w) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int k = 0; k < out_h; k++) { // loop for images
int in_k = static_cast<int>(ratio_h * k + 0.5);
for (int l = 0; l < out_w; l++) {
int in_l = static_cast<int>(ratio_w * l + 0.5);
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l);
}
}
}
}
}
template <typename T>
static void BilinearInterpolationGrad(const Tensor& output_grad,
Tensor* input_grad, const float ratio_h,
const float ratio_w, const int in_h,
const int in_w, const int n, const int c,
const int out_h, const int out_w) {
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int k = 0; k < out_h; k++) { // loop for images
int y_n = static_cast<int>(ratio_h * k);
int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
float d_n = ratio_h * k - y_n;
float d_s = 1.f - d_n;
for (int l = 0; l < out_w; l++) {
int x_w = static_cast<int>(ratio_w * l);
int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
float d_w = ratio_w * l - x_w;
float d_e = 1.f - d_w;
for (int i = 0; i < n; i++) { // loop for batches
for (int j = 0; j < c; j++) { // loop for channels
// bilinear interpolation grad
const T grad = output_grad_t(i, j, k, l);
input_grad_t(i, j, y_n, x_w) += static_cast<T>(grad * d_s * d_e);
input_grad_t(i, j, y_s, x_w) += static_cast<T>(grad * d_n * d_e);
input_grad_t(i, j, y_n, x_e) += static_cast<T>(grad * d_s * d_w);
input_grad_t(i, j, y_s, x_e) += static_cast<T>(grad * d_n * d_w);
}
}
}
}
}
template <typename T>
class InterpolateKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
std::string interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = out_size->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
output->mutable_data<T>({n, c, out_h, out_w}, ctx.GetPlace());
auto& device_ctx =
ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, output, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*input, ctx.GetPlace(), output);
return;
}
float ratio_h =
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if ("bilinear" == interp_method) {
BilinearInterpolation<T>(*input, output, ratio_h, ratio_w, in_h, in_w, n,
c, out_h, out_w);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolate<T>(*input, output, ratio_h, ratio_w, n, c,
out_h, out_w);
}
}
};
template <typename T>
class InterpolateGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
std::string interp_method = ctx.Attr<std::string>("interp_method");
int out_h = ctx.Attr<int>("out_h");
int out_w = ctx.Attr<int>("out_w");
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
auto out_size_data = out_size->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
const int n = input->dims()[0];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
input_grad->mutable_data<T>({n, c, in_h, in_w}, ctx.GetPlace());
auto& device_ctx =
ctx.template device_context<platform::CPUDeviceContext>();
math::SetConstant<platform::CPUDeviceContext, T> zero;
zero(device_ctx, input_grad, static_cast<T>(0.0));
if (in_h == out_h && in_w == out_w) {
framework::TensorCopy(*output_grad, ctx.GetPlace(), input_grad);
return;
}
float ratio_h =
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if ("bilinear" == interp_method) {
BilinearInterpolationGrad<T>(*output_grad, input_grad, ratio_h, ratio_w,
in_h, in_w, n, c, out_h, out_w);
} else if ("nearest" == interp_method) {
NearestNeighborInterpolateGrad<T>(*output_grad, input_grad, ratio_h,
ratio_w, n, c, out_h, out_w);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -24,21 +24,30 @@ namespace gen {
using namespace platform::jit; // NOLINT
bool VVVJitCode::init(int d) {
bool VXXJitCode::init(int d, int scalar_index) {
// It's not necessary to use avx512 since it would slow down the frequency
// and this kernel is not compute bound.
return MayIUse(avx);
return MayIUse(avx) && scalar_index >= 0 && scalar_index <= 2;
}
void VVVJitCode::generate() {
void VXXJitCode::generate() {
// do not need push stack, and do not need save avx512reg if do not use avx512
int offset = 0;
if (with_relu_) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
if (scalar_index_ == 1) {
vbroadcastss(ymm_src1, ptr[param1]);
} else if (scalar_index_ == 2) {
vbroadcastss(ymm_src2, ptr[param2]);
}
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
vmovups(ymm_src1, ptr[param1 + offset]);
vmovups(ymm_src2, ptr[param2 + offset]);
if (scalar_index_ != 1) {
vmovups(ymm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(ymm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::add) {
......@@ -52,8 +61,12 @@ void VVVJitCode::generate() {
}
int rest = num_ % AVX_FLOAT_BLOCK;
if (rest >= 4) {
vmovups(xmm_src1, ptr[param1 + offset]);
vmovups(xmm_src2, ptr[param2 + offset]);
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
......@@ -67,8 +80,12 @@ void VVVJitCode::generate() {
rest -= 4;
}
if (rest >= 2) {
vmovq(xmm_src1, ptr[param1 + offset]);
vmovq(xmm_src2, ptr[param2 + offset]);
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
vmulps(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
......@@ -82,8 +99,12 @@ void VVVJitCode::generate() {
rest -= 2;
}
if (rest > 0) {
vmovss(xmm_src1, ptr[param1 + offset]);
vmovss(xmm_src2, ptr[param2 + offset]);
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
vmulss(xmm_dst, xmm_src1, xmm_src2);
} else if (type_ == operand_type::add) {
......@@ -96,6 +117,7 @@ void VVVJitCode::generate() {
}
ret();
}
} // namespace gen
} // namespace jitkernel
} // namespace math
......
......@@ -29,33 +29,46 @@ using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;
// function: vec = Operand(vec, vec) (maybe with relu)
typedef enum { mul = 0, add } operand_type;
class VVVJitCode : public JitCode {
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode {
public:
const char* name() const override {
std::string base = "VVVJitCode";
std::string base = "VXXJitCode";
if (scalar_index_ == 1) {
base += "_Scalar";
} else {
base += "_Vec";
}
if (type_ == operand_type::mul) {
base += "_Mul";
} else if (type_ == operand_type::add) {
base += "_Add";
}
base += (with_relu_ ? "_relu" : "");
if (scalar_index_ == 2) {
base += "_Scalar";
} else {
base += "_Vec";
}
base += (with_relu_ ? "_Relu" : "");
return base.c_str();
}
explicit VVVJitCode(int d, operand_type type, bool with_relu,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
explicit VXXJitCode(int d, operand_type type, int scalar_index,
bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {}
static bool init(int d);
static bool init(int d, int scalar_index = 0);
void generate() override;
private:
int num_;
operand_type type_;
int scalar_index_;
bool with_relu_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
......@@ -63,13 +76,13 @@ class VVVJitCode : public JitCode {
xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(1);
xmm_t xmm_zero = xmm_t(2);
xmm_t xmm_dst = xmm_t(2);
xmm_t xmm_zero = xmm_t(3);
ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(1);
ymm_t ymm_zero = ymm_t(2);
ymm_t ymm_dst = ymm_t(2);
ymm_t ymm_zero = ymm_t(3);
};
} // namespace gen
......
......@@ -83,14 +83,15 @@ class VAddReluKernel : public Kernel {
template <typename T>
class VScalKernel : public Kernel {
public:
virtual void Compute(const T a, const T *x, T *y) const = 0;
virtual void Compute(const T a, T *x) const = 0;
// y = a.*x
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
class VAddBiasKernel : public Kernel {
public:
virtual void Compute(const T a, const T *x, T *y) const = 0;
// y = a.+x
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
......
......@@ -57,6 +57,20 @@ void VAddReluRefer(const T* x, const T* y, T* z, int n) {
}
}
template <typename T>
void VScalRefer(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] * x[i];
}
}
template <typename T>
void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = a[0] + x[i];
}
}
#ifdef PADDLE_WITH_MKLML
template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n);
......@@ -83,6 +97,28 @@ template <>
void VAddMKL<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdAdd(n, x, y, z);
}
template <typename T>
void VScalMKL(const T* a, const T* x, T* y, int n);
template <>
void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1);
} else {
VScalRefer<float>(a, x, y, n);
}
}
template <>
void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1);
} else {
VScalRefer<double>(a, x, y, n);
}
}
#endif
#define DECLARE_STATIC_FUNC \
......@@ -102,7 +138,7 @@ class VMulKernelImpl : public VMulKernel<T> {
if (useJIT(d)) {
// roughly estimate the size of code
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::mul, false,
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
......@@ -121,14 +157,14 @@ class VMulKernelImpl : public VMulKernel<T> {
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VMulKernelImpl<float>::useJIT(int d) {
return gen::VVVJitCode::init(d);
return gen::VXXJitCode::init(d);
}
#endif
......@@ -153,7 +189,7 @@ class VAddKernelImpl : public VAddKernel<T> {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, false,
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
......@@ -171,14 +207,14 @@ class VAddKernelImpl : public VAddKernel<T> {
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VAddKernelImpl<float>::useJIT(int d) {
return gen::VVVJitCode::init(d);
return gen::VXXJitCode::init(d);
}
#endif
......@@ -203,7 +239,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VVVJitCode(d, gen::operand_type::add, true,
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
......@@ -215,148 +251,106 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VVVJitCode> jitcode_{nullptr};
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VAddReluKernelImpl<float>::useJIT(int d) {
return gen::VVVJitCode::init(d);
return gen::VXXJitCode::init(d);
}
#endif
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
/* VSCAL JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
/* VScal JitKernel */
template <typename T>
class VScalKernelImpl : public VScalKernel<T> {
public:
explicit VScalKernelImpl(int d) : VScalKernel<T>() { this->num_ = d; }
void Compute(const T a, const T* x, T* y) const override {
for (int i = 0; i < this->num_; ++i) {
y[i] = a * x[i];
}
}
void Compute(const T a, T* x) const override {
for (int i = 0; i < this->num_; ++i) {
x[i] = a * x[i];
DECLARE_STATIC_FUNC;
explicit VScalKernelImpl(int d) : VScalKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
}
};
#endif
#ifdef PADDLE_WITH_MKLML
#define MKL_FLOAT(isa, block) \
template <> \
void VScalKernelImpl<float, isa, block>::Compute(const float a, float* x) \
const { \
platform::dynload::cblas_sscal(this->num_, a, x, 1); \
}
#define MKL_DOUBLE(isa, block) \
template <> \
void VScalKernelImpl<double, isa, block>::Compute(const double a, double* x) \
const { \
platform::dynload::cblas_dscal(this->num_, a, x, 1); \
}
FOR_EACH_ISA(MKL_FLOAT, kGT16);
FOR_EACH_ISA_BLOCK(MKL_DOUBLE);
if (useMKL(d)) {
this->Compute = VScalMKL<T>;
return;
}
#endif
#define INTRI8_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute( \
const float a, const float* x, float* y) const { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI8_INPLACE_FLOAT(isa) \
template <> \
void VScalKernelImpl<float, isa, kEQ8>::Compute(const float a, float* x) \
const { \
__m256 tmp; \
__m256 scalar = _mm256_set1_ps(a); \
tmp = _mm256_loadu_ps(x); \
tmp = _mm256_mul_ps(tmp, scalar); \
_mm256_storeu_ps(x, tmp); \
this->Compute = VScalRefer<T>;
}
#ifdef PADDLE_WITH_XBYAK
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI8_INPLACE_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI8_INPLACE_FLOAT(jit::avx2);
private:
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
INTRI8_INPLACE_FLOAT(jit::avx512f);
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VScalKernelImpl<float>::useJIT(int d) {
return gen::VXXJitCode::init(d, 1);
}
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef INTRI8_INPLACE_FLOAT
#undef MKL_FLOAT
#undef MKL_DOUBLE
#ifdef PADDLE_WITH_MKLML
template <>
bool VScalKernelImpl<float>::useMKL(int d) {
return d > 512;
}
template <>
bool VScalKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
/* VAddBias JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
template <typename T>
class VAddBiasKernelImpl : public VAddBiasKernel<T> {
public:
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() { this->num_ = d; }
void Compute(const T a, const T* x, T* y) const override {
for (int i = 0; i < this->num_; ++i) {
y[i] = x[i] + a;
DECLARE_STATIC_FUNC;
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
}
};
#define INTRI8_FLOAT(isa) \
template <> \
void VAddBiasKernelImpl<float, isa, kEQ8>::Compute( \
const float a, const float* x, float* y) const { \
__m256 tmp = _mm256_loadu_ps(x); \
tmp = _mm256_add_ps(tmp, _mm256_set1_ps(a)); \
_mm256_storeu_ps(y, tmp); \
}
#endif
#define INTRI16_FLOAT(isa) \
template <> \
void VAddBiasKernelImpl<float, isa, kEQ16>::Compute( \
const float a, const float* x, float* y) const { \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = _mm256_add_ps(tmp0, _mm256_set1_ps(a)); \
tmp1 = _mm256_add_ps(tmp1, _mm256_set1_ps(a)); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
this->Compute = VAddBiasRefer<T>;
}
#ifdef PADDLE_WITH_XBYAK
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
private:
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VAddBiasKernelImpl<float>::useJIT(int d) {
return gen::VXXJitCode::init(d, 1);
}
#endif
// TODO(TJ): eq16 test and complete avx512
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
/* VRelu JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
......@@ -467,8 +461,6 @@ class VIdentityKernelImpl : public VIdentityKernel<T> {
void Compute(const T* x, T* y) const override {}
};
REGISTER_JITKERNEL_DEPRECATED(vscal, VScalKernel);
REGISTER_JITKERNEL_DEPRECATED(vaddb, VAddBiasKernel);
REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
......
......@@ -409,10 +409,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
vaddbias_ = KernelPool::Instance().template Get<VAddBiasKernel<T>>(d);
}
void Compute(const T* x, T* y) const override {
vscal_->Compute(static_cast<T>(2), x, y);
const T a = static_cast<T>(2), b = static_cast<T>(-1);
vscal_->Compute(&a, x, y, this->num_);
vsigmoid_->Compute(y, y);
vscal_->Compute(static_cast<T>(2), y);
vaddbias_->Compute(static_cast<T>(-1), y, y);
vscal_->Compute(&a, y, y, this->num_);
vaddbias_->Compute(&b, y, y, this->num_);
}
private:
......@@ -472,10 +473,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
_mm256_storeu_ps(y, tmp); \
x += AVX_FLOAT_BLOCK; \
y += AVX_FLOAT_BLOCK; \
vscal_->Compute(2.f, x, y); \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(2.f, y); \
vaddbias_->Compute(-1.f, y, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(&b, y, y, this->num_); \
}
#define INTRI_GT16_FLOAT(isa, expisa) \
......@@ -502,10 +504,11 @@ class VTanhKernelImpl : public VTanhKernel<T> {
} \
x += this->end_; \
y += this->end_; \
vscal_->Compute(2.f, x, y); \
const float a = 2.f, b = -1.f; \
vscal_->Compute(&a, x, y, this->num_); \
vsigmoid_->Compute(y, y); \
vscal_->Compute(2.f, y); \
vaddbias_->Compute(-1.f, y, y); \
vscal_->Compute(&a, y, y, this->num_); \
vaddbias_->Compute(&b, y, y, this->num_); \
}
#ifdef __AVX__
......
......@@ -129,7 +129,7 @@ TEST(JitKernel, vaddbias) {
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(a, x_data, ztgt_data);
ker->Compute(&a, x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
......@@ -285,10 +285,11 @@ void vtanh_better(
const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
vaddbias,
const int n, const float* x, float* y) {
vscal->Compute(2.f, x, y);
const float a = 2.f, b = -1.f;
vscal->Compute(&a, x, y, n);
vsigmoid->Compute(y, y);
vscal->Compute(2.f, y);
vaddbias->Compute(-1.f, y, y);
vscal->Compute(&a, y, y, n);
vaddbias->Compute(&b, y, y, n);
}
TEST(JitKernel, vtanh) {
......@@ -537,12 +538,12 @@ TEST(JitKernel, vscal) {
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(a, x_data, ztgt_data);
ker->Compute(&a, x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
auto ttgts1 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(a, y_data);
ker->Compute(&a, y_data, y_data, d);
}
auto ttgte1 = GetCurrentUS();
VLOG(30) << "Vec size " << d
......
......@@ -204,7 +204,10 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
<< "." << (driver_version_ % 100) / 10
<< ", Runtime Version: " << runtime_version_ / 1000
<< "." << (runtime_version_ % 100) / 10;
size_t cudnn_dso_ver = dynload::cudnnGetVersion();
LOG_FIRST_N(WARNING, 1) << "device: " << place_.device
<< ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
<< (cudnn_dso_ver % 100) / 10 << ".";
callback_manager_.reset(new StreamCallbackManager(stream_));
}
......
......@@ -65,51 +65,54 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
* include all needed cudnn functions in HPPL
* different cudnn version has different interfaces
**/
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor); \
__macro(cudnnSetTensor4dDescriptorEx); \
__macro(cudnnSetTensorNdDescriptor); \
__macro(cudnnGetTensorNdDescriptor); \
__macro(cudnnGetConvolutionNdForwardOutputDim); \
__macro(cudnnGetConvolutionForwardAlgorithm); \
__macro(cudnnCreateTensorDescriptor); \
__macro(cudnnDestroyTensorDescriptor); \
__macro(cudnnCreateFilterDescriptor); \
__macro(cudnnSetFilter4dDescriptor); \
__macro(cudnnSetFilterNdDescriptor); \
__macro(cudnnGetFilterNdDescriptor); \
__macro(cudnnSetPooling2dDescriptor); \
__macro(cudnnSetPoolingNdDescriptor); \
__macro(cudnnGetPoolingNdDescriptor); \
__macro(cudnnDestroyFilterDescriptor); \
__macro(cudnnCreateConvolutionDescriptor); \
__macro(cudnnCreatePoolingDescriptor); \
__macro(cudnnDestroyPoolingDescriptor); \
__macro(cudnnSetConvolution2dDescriptor); \
__macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreateSpatialTransformerDescriptor); \
__macro(cudnnSetSpatialTransformerNdDescriptor); \
__macro(cudnnDestroySpatialTransformerDescriptor); \
__macro(cudnnSpatialTfGridGeneratorForward); \
__macro(cudnnSpatialTfGridGeneratorBackward); \
__macro(cudnnSpatialTfSamplerForward); \
__macro(cudnnSpatialTfSamplerBackward); \
__macro(cudnnCreate); \
__macro(cudnnDestroy); \
__macro(cudnnSetStream); \
__macro(cudnnActivationForward); \
__macro(cudnnConvolutionForward); \
__macro(cudnnConvolutionBackwardBias); \
__macro(cudnnGetConvolutionForwardWorkspaceSize); \
__macro(cudnnTransformTensor); \
__macro(cudnnPoolingForward); \
__macro(cudnnPoolingBackward); \
__macro(cudnnSoftmaxBackward); \
__macro(cudnnSoftmaxForward); \
__macro(cudnnGetVersion); \
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor); \
__macro(cudnnSetTensor4dDescriptorEx); \
__macro(cudnnSetTensorNdDescriptor); \
__macro(cudnnGetTensorNdDescriptor); \
__macro(cudnnGetConvolutionNdForwardOutputDim); \
__macro(cudnnGetConvolutionForwardAlgorithm); \
__macro(cudnnCreateTensorDescriptor); \
__macro(cudnnDestroyTensorDescriptor); \
__macro(cudnnCreateFilterDescriptor); \
__macro(cudnnSetFilter4dDescriptor); \
__macro(cudnnSetFilterNdDescriptor); \
__macro(cudnnGetFilterNdDescriptor); \
__macro(cudnnSetPooling2dDescriptor); \
__macro(cudnnSetPoolingNdDescriptor); \
__macro(cudnnGetPoolingNdDescriptor); \
__macro(cudnnDestroyFilterDescriptor); \
__macro(cudnnCreateConvolutionDescriptor); \
__macro(cudnnCreatePoolingDescriptor); \
__macro(cudnnDestroyPoolingDescriptor); \
__macro(cudnnSetConvolution2dDescriptor); \
__macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnSetConvolutionNdDescriptor); \
__macro(cudnnGetConvolutionNdDescriptor); \
__macro(cudnnDeriveBNTensorDescriptor); \
__macro(cudnnCreateSpatialTransformerDescriptor); \
__macro(cudnnSetSpatialTransformerNdDescriptor); \
__macro(cudnnDestroySpatialTransformerDescriptor); \
__macro(cudnnSpatialTfGridGeneratorForward); \
__macro(cudnnSpatialTfGridGeneratorBackward); \
__macro(cudnnSpatialTfSamplerForward); \
__macro(cudnnSpatialTfSamplerBackward); \
__macro(cudnnCreate); \
__macro(cudnnDestroy); \
__macro(cudnnSetStream); \
__macro(cudnnActivationForward); \
__macro(cudnnConvolutionForward); \
__macro(cudnnConvolutionBackwardBias); \
__macro(cudnnGetConvolutionForwardWorkspaceSize); \
__macro(cudnnTransformTensor); \
__macro(cudnnPoolingForward); \
__macro(cudnnPoolingBackward); \
__macro(cudnnSoftmaxBackward); \
__macro(cudnnSoftmaxForward); \
__macro(cudnnGetVersion); \
__macro(cudnnFindConvolutionForwardAlgorithmEx); \
__macro(cudnnFindConvolutionBackwardFilterAlgorithmEx); \
__macro(cudnnFindConvolutionBackwardDataAlgorithmEx); \
__macro(cudnnGetErrorString);
CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
......
......@@ -126,7 +126,8 @@ def __bootstrap__():
if core.is_compiled_with_cuda():
read_env_flags += [
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic'
'fraction_of_gpu_memory_to_use', 'cudnn_deterministic',
'conv_workspace_size_limit', 'cudnn_exhaustive_search'
]
core.init_gflags([sys.argv[0]] +
["--tryfromenv=" + ",".join(read_env_flags)])
......
......@@ -27,6 +27,7 @@ from .tensor import concat
from . import utils
from .. import unique_name
from functools import reduce
from .. import core
__all__ = [
'fc',
......@@ -101,6 +102,7 @@ __all__ = [
'image_resize',
'image_resize_short',
'resize_bilinear',
'resize_nearest',
'gather',
'scatter',
'sequence_scatter',
......@@ -1665,6 +1667,20 @@ def conv2d(input,
pre_bias = helper.create_variable_for_type_inference(dtype)
if use_cudnn:
helper.create_variable(
name="kCUDNNFwdAlgoCache",
persistable=True,
type=core.VarDesc.VarType.RAW)
helper.create_variable(
name="kCUDNNBwdDataAlgoCache",
persistable=True,
type=core.VarDesc.VarType.RAW)
helper.create_variable(
name="kCUDNNBwdFilterAlgoCache",
persistable=True,
type=core.VarDesc.VarType.RAW)
helper.append_op(
type=l_type,
inputs={
......@@ -1678,7 +1694,7 @@ def conv2d(input,
'dilations': dilation,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False
'use_mkldnn': False,
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
......@@ -5640,7 +5656,8 @@ def image_resize(input,
out_shape=None,
scale=None,
name=None,
resample='BILINEAR'):
resample='BILINEAR',
actual_shape=None):
"""
**Resize a Batch of Images**
......@@ -5650,6 +5667,7 @@ def image_resize(input,
Supporting resample methods:
'BILINEAR' : Bilinear interpolation
'NEAREST' : Nearest neighbor interpolation
Args:
input (Variable): The input tensor of image resize layer,
......@@ -5664,25 +5682,51 @@ def image_resize(input,
Default: None
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
resample(str): The resample method. It can only be 'BILINEAR' currently.
resample(str): The resample method. It supports 'BILINEAR' and 'NEAREST'
currently.
Default: 'BILINEAR'
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
according to this given shape rather than
:attr:`out_shape` and :attr:`scale` specifying
shape. That is to say actual_shape has the
highest priority. It is recommended to use
actual_shape instead of :attr:`out_shape` if you
want to specify output shape dynamically. When
using actual_shape to specify output shape, one of
:attr:`out_shape` and :attr:`scale` should also be
set, otherwise errors would be occured in graph
constructing stage.
Default: None
Returns:
Variable: The output is a 4-D tensor of the shape
(num_batches, channls, out_h, out_w).
Raises:
TypeError: out_shape should be a list or tuple or Variable.
TypeError: actual_shape should either be Variable or None.
ValueError: The 'resample' of image_resize can only be 'BILINEAR'
or 'NEAREST' currently.
ValueError: One of out_shape and scale must not be None.
ValueError: out_shape length should be 2.
Examples:
.. code-block:: python
out = fluid.layers.image_resize(input, out_shape=[12, 12])
"""
resample_methods = {'BILINEAR': 'bilinear_interp'}
resample_methods = {
'BILINEAR': 'bilinear',
'NEAREST': 'nearest',
}
if resample not in resample_methods:
raise ValueError(
"The 'resample' of image_resize can only be 'BILINEAR' currently.")
"The 'resample' of image_resize can only be 'BILINEAR' or 'NEAREST' currently."
)
if out_shape is None and scale is None:
raise ValueError("One of out_shape and scale must not be None")
helper = LayerHelper('bilinear_interp', **locals())
raise ValueError("One of out_shape and scale must not be None.")
helper = LayerHelper('interpolate', **locals())
dtype = helper.input_dtype()
def _is_list_or_turple_(data):
......@@ -5692,33 +5736,106 @@ def image_resize(input,
out_w = 0
inputs = {"X": input}
if out_shape is not None:
if not (_is_list_or_turple_(out_shape) and
len(out_shape) == 2) and not isinstance(out_shape, Variable):
raise ValueError('out_shape should be a list or tuple or variable')
if _is_list_or_turple_(out_shape):
out_shape = list(map(int, out_shape))
out_h = out_shape[0]
out_w = out_shape[1]
else:
if isinstance(out_shape, Variable):
warnings.warn("out_shape as Variable type is deprecated, \
it is recommended to use actual_shape instead of \
out_shape to specify output shape dynamically.")
inputs['OutSize'] = out_shape
elif not (_is_list_or_turple_(out_shape)):
raise TypeError("out_shape should be a list or tuple or Variable.")
elif len(out_shape) != 2:
raise ValueError("out_shape length should be 2.")
out_shape = list(map(int, out_shape))
out_h = out_shape[0]
out_w = out_shape[1]
else:
out_h = int(input.shape[2] * scale)
out_w = int(input.shape[3] * scale)
if isinstance(actual_shape, Variable):
inputs["OutSize"] = actual_shape
elif actual_shape is not None:
raise TypeError("actual_shape should either be Variable or None.")
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type=resample_methods[resample],
type='interpolate',
inputs=inputs,
outputs={"Out": out},
attrs={"out_h": out_h,
"out_w": out_w})
attrs={
"out_h": out_h,
"out_w": out_w,
"interp_method": resample_methods[resample]
})
return out
@templatedoc(op_type="bilinear_interp")
def resize_bilinear(input, out_shape=None, scale=None, name=None):
@templatedoc(op_type="interpolate")
def resize_bilinear(input,
out_shape=None,
scale=None,
name=None,
actual_shape=None):
"""
${comment}
Resize input by performing bilinear interpolation based on given
output shape which specified by actual_shape, out_shape and scale
in priority order.
Bilinear interpolation is an extension of linear interpolation for
interpolating functions of two variables (e.g. H-direction and
W-direction in this op) on a rectilinear 2D grid. The key idea is
to perform linear interpolation first in one direction, and then
again in the other direction.
For details of bilinear interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Bilinear_interpolation
Args:
input(${x_type}): ${x_comment}.
out_shape(${out_size_type}): ${out_size_comment}.
scale(float|None): The multiplier for the input height or width. At
least one of out_shape or scale must be set. And out_shape has
a higher priority than scale. Default: None.
name(str|None): The output variable name.
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
according to this given shape rather than
:attr:`out_shape` and :attr:`scale` specifying
shape. That is to say actual_shape has the
highest priority. It is recommended to use
actual_shape instead of :attr:`out_shape` if you
want to specify output shape dynamically. When
using actual_shape to specify output shape, one of
:attr:`out_shape` and :attr:`scale` should also be
set, otherwise errors would be occured in graph
constructing stage.
Default: None
Returns:
${out_comment}.
"""
return image_resize(input, out_shape, scale, name, 'BILINEAR', actual_shape)
@templatedoc(op_type="interpolate")
def resize_nearest(input,
out_shape=None,
scale=None,
name=None,
actual_shape=None):
"""
Resize input by performing nearest neighbor interpolation in both the
3rd dimention(in height direction) and the 4th dimention(in width
direction) based on given output shape which specified by actual_shape,
out_shape and scale in priority order.
For details of nearest neighbor interpolation, please refer to Wikipedia:
https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation
Args:
input(${x_type}): ${x_comment}.
......@@ -5730,12 +5847,25 @@ def resize_bilinear(input, out_shape=None, scale=None, name=None):
a higher priority than scale. Default: None.
name(str|None): The output variable name.
actual_shape(Variable): An optional input to specify output shape
dynamically. If provided, image resize
according to this given shape rather than
:attr:`out_shape` and :attr:`scale` specifying
shape. That is to say actual_shape has the
highest priority. It is recommended to use
actual_shape instead of :attr:`out_shape` if you
want to specify output shape dynamically. When
using actual_shape to specify output shape, one of
:attr:`out_shape` and :attr:`scale` should also be
set, otherwise errors would be occured in graph
constructing stage.
Default: None
Returns:
${out_comment}.
"""
return image_resize(input, out_shape, scale, name, 'BILINEAR')
return image_resize(input, out_shape, scale, name, 'NEAREST', actual_shape)
def image_resize_short(input, out_short_len, resample='BILINEAR'):
......
......@@ -67,6 +67,7 @@ class TestConv2dOp(OpTest):
def setUp(self):
self.op_type = "conv2d"
self.use_cudnn = False
self.exhaustive_search = False
self.use_cuda = False
self.use_mkldnn = False
self.data_format = "AnyLayout"
......@@ -98,7 +99,8 @@ class TestConv2dOp(OpTest):
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format
'data_format': self.data_format,
'exhaustive_search': self.exhaustive_search
}
self.outputs = {'Output': output}
......@@ -361,6 +363,12 @@ class TestDepthwiseConvWithDilation2(TestConv2dOp):
self.op_type = "depthwise_conv2d"
class TestCUDNNExhaustiveSearch(TestConv2dOp):
def init_kernel_type(self):
self.use_cudnn = True
self.exhaustive_search = True
# Please Don't remove the following code.
# Currently, CI use cudnn V5.0 which not support dilation conv.
# class TestCUDNNWithDilation(TestWithDilation):
......
......@@ -335,6 +335,12 @@ class TestFP16WithInput1x1Filter1x1CUDNN(TestWithInput1x1Filter1x1):
self.check_output_with_place(place, atol=2e-2)
class TestCUDNNExhaustiveSearch(TestCUDNN):
def init_kernel_type(self):
self.use_cudnn = True
self.exhaustive_search = True
# FIXME(typhoonzero): find a way to determine if
# using cudnn > 6 in python
# class TestWithDilationCUDNN(TestWithDilation):
......
......@@ -98,17 +98,18 @@ class TestDistRunnerBase(object):
strategy.allow_op_delay = False
build_stra = fluid.BuildStrategy()
if args.batch_merge_repeat > 1:
pass_builder = build_stra._create_passes_from_strategy()
mypass = pass_builder.insert_pass(
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
mypass.set_int("num_repeats", args.batch_merge_repeat)
if args.use_reduce:
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
else:
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
if args.batch_merge_repeat > 1:
pass_builder = build_stra._create_passes_from_strategy()
mypass = pass_builder.insert_pass(
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
mypass.set_int("num_repeats", args.batch_merge_repeat)
exe = fluid.ParallelExecutor(
args.use_cuda,
loss_name=avg_cost.name,
......
......@@ -20,10 +20,44 @@ from op_test import OpTest
import paddle.fluid.core as core
def bilinear_interp_np(input, out_h, out_w, out_size):
def nearest_neighbor_interp_np(X,
out_h,
out_w,
out_size=None,
actual_shape=None):
"""nearest neighbor interpolation implement in shape [N, C, H, W]"""
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
if actual_shape is not None:
out_h = actual_shape[0]
out_w = actual_shape[1]
n, c, in_h, in_w = X.shape
ratio_h = ratio_w = 0.0
if out_h > 1:
ratio_h = (in_h - 1.0) / (out_h - 1.0)
if out_w > 1:
ratio_w = (in_w - 1.0) / (out_w - 1.0)
out = np.zeros((n, c, out_h, out_w))
for i in range(out_h):
in_i = int(ratio_h * i + 0.5)
for j in range(out_w):
in_j = int(ratio_w * j + 0.5)
out[:, :, i, j] = X[:, :, in_i, in_j]
return out.astype(X.dtype)
def bilinear_interp_np(input, out_h, out_w, out_size=None, actual_shape=None):
"""bilinear interpolation implement in shape [N, C, H, W]"""
if out_size is not None:
out_h = out_size[0]
out_w = out_size[1]
if actual_shape is not None:
out_h = actual_shape[0]
out_w = actual_shape[1]
batch_size, channel, in_h, in_w = input.shape
if out_h > 1:
ratio_h = (in_h - 1.0) / (out_h - 1.0)
......@@ -53,18 +87,32 @@ def bilinear_interp_np(input, out_h, out_w, out_size):
return out.astype(input.dtype)
class TestBilinearInterpOp(OpTest):
INTERPOLATE_FUNCS = {
'bilinear': bilinear_interp_np,
'nearest': nearest_neighbor_interp_np,
}
class TestInterpolateOp(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "bilinear_interp"
self.op_type = "interpolate"
input_np = np.random.random(self.input_shape).astype("float32")
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
self.out_size)
output_np = INTERPOLATE_FUNCS[self.interp_method](
input_np, self.out_h, self.out_w, self.out_size, self.actual_shape)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {'out_h': self.out_h, 'out_w': self.out_w}
if self.actual_shape is not None:
self.inputs['OutSize'] = self.actual_shape
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method
}
self.outputs = {'Out': output_np}
def test_check_output(self):
......@@ -74,90 +122,209 @@ class TestBilinearInterpOp(OpTest):
self.check_grad(['X'], 'Out', in_place=True)
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 4, 4]
self.out_h = 2
self.out_w = 2
self.out_size = np.array([3, 3]).astype("int32")
class TestCase1(TestBilinearInterpOp):
class TestBilinearInterpCase1(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
class TestCase2(TestBilinearInterpOp):
class TestBilinearInterpCase2(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
class TestCase3(TestBilinearInterpOp):
class TestBilinearInterpCase3(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
class TestCase4(TestBilinearInterpOp):
class TestBilinearInterpCase4(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.out_size = np.array([2, 2]).astype("int32")
class TestCase5(TestBilinearInterpOp):
class TestBilinearInterpCase5(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.out_size = np.array([11, 11]).astype("int32")
class TestCase6(TestBilinearInterpOp):
class TestBilinearInterpCase6(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
self.out_size = np.array([65, 129]).astype("int32")
class TestBilinearInterpOpUint8(OpTest):
class TestBilinearInterpActualShape(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.out_size = np.array([66, 40]).astype("int32")
class TestBilinearInterpBigScale(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 4, 64, 32]
self.out_h = 100
self.out_w = 50
self.out_size = np.array([101, 51]).astype('int32')
class TestInterpolateOpUint8(OpTest):
def setUp(self):
self.out_size = None
self.actual_shape = None
self.init_test_case()
self.op_type = "bilinear_interp"
self.op_type = "interpolate"
input_np = np.random.randint(
low=0, high=256, size=self.input_shape).astype("uint8")
output_np = bilinear_interp_np(input_np, self.out_h, self.out_w,
self.out_size)
output_np = INTERPOLATE_FUNCS[self.interp_method](
input_np, self.out_h, self.out_w, self.out_size, self.actual_shape)
self.inputs = {'X': input_np}
if self.out_size is not None:
self.inputs['OutSize'] = self.out_size
self.attrs = {'out_h': self.out_h, 'out_w': self.out_w}
self.attrs = {
'out_h': self.out_h,
'out_w': self.out_w,
'interp_method': self.interp_method
}
self.outputs = {'Out': output_np}
def test_check_output(self):
self.check_output_with_place(place=core.CPUPlace(), atol=1)
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [1, 3, 9, 6]
self.out_h = 10
self.out_w = 9
class TestCase1Uint8(TestBilinearInterpOpUint8):
class TestBilinearInterpCase1Uint8(TestInterpolateOpUint8):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 128, 64]
self.out_h = 120
self.out_w = 50
class TestBilinearInterpCase2Uint8(TestInterpolateOpUint8):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [4, 1, 7, 8]
self.out_h = 5
self.out_w = 13
self.out_size = np.array([6, 15]).astype("int32")
class TestNearestNeighborInterpCase1(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
class TestNearestNeighborInterpCase2(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
class TestNearestNeighborInterpCase3(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
class TestNearestNeighborInterpCase4(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 1
self.out_w = 1
self.out_size = np.array([2, 2]).astype("int32")
class TestNearestNeighborInterpCase5(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 3, 9, 6]
self.out_h = 12
self.out_w = 12
self.out_size = np.array([11, 11]).astype("int32")
class TestNearestNeighborInterpCase6(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [1, 1, 128, 64]
self.out_h = 64
self.out_w = 128
self.out_size = np.array([65, 129]).astype("int32")
class TestNearestNeighborInterpActualShape(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16]
self.out_h = 64
self.out_w = 32
self.out_size = np.array([66, 40]).astype("int32")
class TestNearestNeighborInterpBigScale(TestInterpolateOp):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 4, 64, 32]
self.out_h = 100
self.out_w = 50
self.out_size = np.array([101, 51]).astype('int32')
class TestNearestNeighborInterpCase1Uint8(TestInterpolateOpUint8):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [2, 3, 128, 64]
self.out_h = 120
self.out_w = 50
class TestCase2Uint8(TestBilinearInterpOpUint8):
class TestNearestNeighborInterpCase2Uint8(TestInterpolateOpUint8):
def init_test_case(self):
self.interp_method = 'nearest'
self.input_shape = [4, 1, 7, 8]
self.out_h = 5
self.out_w = 13
......
......@@ -496,6 +496,16 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(output)
print(str(program))
def test_resize_nearest(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[3, 9, 6], dtype="float32")
output = layers.resize_nearest(x, out_shape=[12, 12])
self.assertIsNotNone(output)
output = layers.resize_nearest(x, scale=3)
self.assertIsNotNone(output)
print(str(program))
def test_polygon_box_transform(self):
program = Program()
with program_guard(program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册