提交 9f8f0fc2 编写于 作者: D Dun 提交者: chengduo

Memory optimization of depthwise conv op and group norm op (#15313)

* mem opt

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* refine code  test=develop

* refine code  test=develop

* refine code  test=develop

* refine code  test=develop

* refine with cub test=develop

* fix mkldnn test && remove comments && test=develop

* polish code && test=develop

* add only_forward test && test=develop
上级 9252aa41
......@@ -94,4 +94,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass multi_batch_merge_pass
fuse_relu_depthwise_conv_pass
memory_optimize_pass lock_free_optimize_pass)
......@@ -55,6 +55,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
}
// Add op fusion.
if (strategy.fuse_relu_depthwise_conv_) {
AppendPass("fuse_relu_depthwise_conv_pass");
}
if (strategy.fuse_elewise_add_act_ops_) {
auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass");
// Add a graph viz pass to record a graph.
......@@ -210,6 +213,12 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->Set<const std::vector<OpDesc *>>(
kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
} else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
if (!use_cuda) {
LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
"GPU, skipped.";
continue;
}
}
graph = pass->Apply(std::move(graph));
}
......@@ -220,6 +229,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
} // namespace framework
} // namespace paddle
USE_PASS(fuse_relu_depthwise_conv_pass);
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
......
......@@ -74,6 +74,8 @@ struct BuildStrategy {
bool fuse_elewise_add_act_ops_{false};
bool fuse_relu_depthwise_conv_{false};
bool memory_optimize_{false};
bool memory_early_delete_{false};
......
......@@ -70,6 +70,7 @@ if(WITH_MKLDNN)
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector )
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
graph = FuseReluDepthwiseConv(std::move(graph), true);
graph = FuseReluDepthwiseConv(std::move(graph), false);
return graph;
}
std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
std::unique_ptr<ir::Graph> graph, bool only_forward) const {
PADDLE_ENFORCE(graph.get());
if (only_forward)
FusePassBase::Init("relu_depthwise_conv_only_forward", graph.get());
else
FusePassBase::Init("relu_depthwise_conv", graph.get());
/*
x ---act--> y ---layer-> z
+----------+
↓ ↓
x' <--act'--- y' <-layer'--- z'
fuse to:
x ---act-layer-> z
|
x' <--act-layer'--- z'
*/
GraphPatternDetector gpd;
auto *pattern = gpd.mutable_pattern();
std::string act_type = "relu";
std::string layer_type = "depthwise_conv2d";
auto *x = pattern->NewNode("x")->AsInput();
auto *y = pattern->NewNode("y")->AsIntermediate();
auto *z = pattern->NewNode("z")->AsOutput();
PDNode *xg = nullptr;
PDNode *yg = nullptr;
PDNode *zg = nullptr;
if (!only_forward) {
xg = pattern->NewNode("xg")->AsOutput();
yg = pattern->NewNode("yg")->AsIntermediate();
zg = pattern->NewNode("zg")->AsInput();
}
PDNode *act_g = nullptr;
PDNode *layer_g = nullptr;
auto *act = pattern->NewNode("act")->assert_is_op(act_type);
auto *layer = pattern->NewNode("layer")->assert_is_op(layer_type);
if (!only_forward) {
act_g = pattern->NewNode("act_g")->assert_is_op(act_type + "_grad");
layer_g = pattern->NewNode("layer_g")->assert_is_op(layer_type + "_grad");
}
act->LinksFrom({x}).LinksTo({y});
layer->LinksFrom({y}).LinksTo({z});
if (!only_forward) {
layer_g->LinksFrom({y, zg}).LinksTo({yg});
act_g->LinksFrom({y, yg}).LinksTo({xg});
}
int count = 0;
std::unordered_set<const Node *> need_removed_nodes;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseReluDepthwiseConv fuse";
// 1. turn on fuse option
auto *layer_op = subgraph.at(layer)->Op();
layer_op->SetAttr("use_cudnn", false);
layer_op->SetAttr("fuse_relu_before_depthwise_conv", true);
OpDesc *layer_g_op = nullptr;
if (!only_forward) {
layer_g_op = subgraph.at(layer_g)->Op();
layer_g_op->SetAttr("use_cudnn", false);
layer_g_op->SetAttr("fuse_relu_before_depthwise_conv", true);
}
// 2. connect x to layer and layer_g, layer_g to xg
auto *y_var = subgraph.at(y)->Var();
auto *x_var = subgraph.at(x)->Var();
VarDesc *yg_var = nullptr;
VarDesc *xg_var = nullptr;
if (!only_forward) {
yg_var = subgraph.at(yg)->Var();
xg_var = subgraph.at(xg)->Var();
}
PADDLE_ENFORCE_EQ(layer_op->Input("Input").size(), 1);
PADDLE_ENFORCE_EQ(layer_op->Input("Input")[0], y_var->Name());
layer_op->SetInput("Input", {x_var->Name()});
subgraph.at(layer)->inputs.push_back(subgraph.at(x));
subgraph.at(x)->outputs.push_back(subgraph.at(layer));
VLOG(4) << "replace " << y_var->Name() << " -> " << x_var->Name();
if (!only_forward) {
PADDLE_ENFORCE_EQ(layer_g_op->Input("Input").size(), 1);
PADDLE_ENFORCE_EQ(layer_g_op->Input("Input")[0], y_var->Name());
layer_g_op->SetInput("Input", {x_var->Name()});
subgraph.at(layer_g)->inputs.push_back(subgraph.at(x));
subgraph.at(x)->outputs.push_back(subgraph.at(layer_g));
PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input")).size(), 1);
PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input"))[0],
yg_var->Name());
layer_g_op->SetOutput(GradVarName("Input"), {xg_var->Name()});
subgraph.at(layer_g)->outputs.push_back(subgraph.at(xg));
subgraph.at(xg)->inputs.push_back(subgraph.at(layer_g));
VLOG(4) << "replace " << yg_var->Name() << " -> " << xg_var->Name();
}
// 3. delete y, yg, act, act_g
if (only_forward) {
need_removed_nodes.insert({subgraph.at(y), subgraph.at(act)});
} else {
need_removed_nodes.insert({subgraph.at(y), subgraph.at(yg),
subgraph.at(act), subgraph.at(act_g)});
}
count++;
};
gpd(graph.get(), handler);
GraphSafeRemoveNodes(graph.get(), need_removed_nodes);
AddStatis(count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_relu_depthwise_conv_pass,
paddle::framework::ir::FuseReluDepthwiseConvPass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the relu and depthwise conv
*/
class FuseReluDepthwiseConvPass : public FusePassBase {
public:
virtual ~FuseReluDepthwiseConvPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
std::unique_ptr<ir::Graph> FuseReluDepthwiseConv(
std::unique_ptr<ir::Graph> graph, bool only_forward) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -143,7 +143,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Get unique name for storing MKLDNN primitives
const std::string key = platform::ConvMKLDNNHandler::GetHash(
src_tz, weights_tz, strides, paddings, dilations, groups,
ctx.op().Output("Output"));
ctx.op().Input("Input") + ctx.op().Input("Filter"));
const std::string key_conv_pd = key + "@conv_pd";
std::vector<primitive> pipeline;
......@@ -371,7 +371,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::ConvMKLDNNHandler::AppendKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), fuse_relu, fuse_residual_conn,
ctx.op().Output("Output"));
ctx.op().Input("Input") + ctx.op().Input("Filter"));
const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false;
......@@ -798,7 +798,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const Tensor* input = ctx.Input<Tensor>("Input");
const Tensor* filter = ctx.Input<Tensor>("Filter");
const Tensor* output = ctx.Input<Tensor>("Output");
const Tensor* output_grad =
ctx.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
......@@ -810,9 +809,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor");
PADDLE_ENFORCE(output->layout() == DataLayout::kMKLDNN &&
output->format() != memory::format::format_undef,
"Wrong layout/format set for Output tensor");
PADDLE_ENFORCE(output_grad->layout() == DataLayout::kMKLDNN &&
output_grad->format() != memory::format::format_undef,
"Wrong layout/format set for output_grad tensor");
......@@ -840,18 +836,19 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::vectorize2int(filter->dims());
int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g, is_conv3d);
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
std::vector<int> dst_tz =
paddle::framework::vectorize2int(output_grad->dims());
auto src_format = input->format();
mkldnn::memory::format weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d);
// Get an unique name from "argument" name of "Output" variable
// Get an unique name from "argument" name of "input" and "Filter" variable
// as well as attributes of primitive to be created
// This name will be used as key when saving info into device context
const std::string key = platform::ConvMKLDNNHandler::GetHash(
src_tz, weights_tz, strides, paddings, dilations, groups,
ctx.op().Input("Output"));
ctx.op().Input("Input") + ctx.op().Input("Filter"));
const std::string key_conv_pd = key + "@conv_pd";
std::vector<primitive> pipeline;
......
......@@ -171,6 +171,9 @@ void Conv2DOpMaker::Make() {
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false);
AddAttr<bool>("fuse_relu_before_depthwise_conv",
"(bool, default false) Only used in cuda depthwise kernel")
.SetDefault(false);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
......@@ -412,18 +415,43 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
customized_type_value);
}
class Conv2dGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType(GradOpType());
op->SetInput("Input", Input("Input"));
op->SetInput("Filter", Input("Filter"));
op->SetInput("Bias", Input("Bias"));
op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(op);
}
virtual std::string GradOpType() const {
return this->ForwardOpType() + "_grad";
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
ops::ConvOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::ConvOpInferVarType, ops::Conv2dGradMaker);
REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad);
// depthwise convolution op
REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::ConvOpInferVarType, ops::Conv2dGradMaker);
REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad);
REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
......
......@@ -397,13 +397,19 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
bool fuse_relu = context.Attr<bool>("fuse_relu_before_depthwise_conv");
auto& dev_ctx = context.template device_context<DeviceContext>();
if (fuse_relu) {
math::DepthwiseConvFunctor<DeviceContext, T, true> depthwiseConv;
depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
output);
} else {
math::DepthwiseConvFunctor<DeviceContext, T, false> depthwiseConv;
depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
output);
}
}
};
template <typename DeviceContext, typename T>
......@@ -424,27 +430,42 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
bool fuse_relu = context.Attr<bool>("fuse_relu_before_depthwise_conv");
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
math::DepthwiseConvInputGradFunctor<DeviceContext, T>
depthwiseConvInputGrad;
math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
depthwiseConvFilterGrad;
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, input_grad, static_cast<T>(0));
if (fuse_relu) {
math::DepthwiseConvInputGradFunctor<DeviceContext, T, true>
depthwiseConvInputGrad;
depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
paddings, dilations, input_grad);
} else {
math::DepthwiseConvInputGradFunctor<DeviceContext, T, false>
depthwiseConvInputGrad;
depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
paddings, dilations, input_grad);
}
}
if (filter_grad) {
filter_grad->mutable_data<T>(context.GetPlace());
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides, paddings,
dilations, filter_grad);
if (fuse_relu) {
math::DepthwiseConvFilterGradFunctor<DeviceContext, T, true>
depthwiseConvFilterGrad;
depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
paddings, dilations, filter_grad);
} else {
math::DepthwiseConvFilterGradFunctor<DeviceContext, T, false>
depthwiseConvFilterGrad;
depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
paddings, dilations, filter_grad);
}
}
}
};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/group_norm_op.h"
#include <string>
namespace paddle {
namespace operators {
......@@ -102,8 +103,8 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
// check input
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mean"),
"Input(Mean) of GroupNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Variance"),
......@@ -113,7 +114,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
// check output
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
}
if (ctx->HasOutput(framework::GradVarName("Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Scale"),
......@@ -145,12 +146,36 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
}
};
class GroupNormGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto *op = new framework::OpDesc();
op->SetType("group_norm_grad");
op->SetInput("Scale", Input("Scale"));
op->SetInput("Bias", Input("Bias"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetInput("Y", Output("Y"));
op->SetInput("Mean", Output("Mean"));
op->SetInput("Variance", Output("Variance"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale"));
op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(group_norm, ops::GroupNormOp, ops::GroupNormOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::GroupNormGradMaker);
REGISTER_OPERATOR(group_norm_grad, ops::GroupNormGradOp);
REGISTER_OP_CPU_KERNEL(
group_norm, ops::GroupNormKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -12,12 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cub/cub.cuh>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/group_norm_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
namespace paddle {
namespace operators {
enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
#define CHECK_CASE(i, flags, kernel_name, args...) \
if (i == flags) { \
kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(args); \
}
// 0 for no scale, no bias
// 1 for has scale, no bias
// 2 for no scale, has bias
// 3 for has scale, has bias
#define UNROLL_ALL_CASES(flags, kernel_name, args...) \
CHECK_CASE(0, flags, kernel_name, args) \
CHECK_CASE(1, flags, kernel_name, args) \
CHECK_CASE(2, flags, kernel_name, args) \
CHECK_CASE(3, flags, kernel_name, args)
template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
typedef cub::WarpReduce<T> WarpReduce;
typename WarpReduce::TempStorage temp_storage;
value = WarpReduce(temp_storage).Sum(value);
if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
}
template <typename T>
__global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C,
int imsize, int groups,
......@@ -36,21 +62,11 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C,
}
x_mean /= number * imsize;
x_var /= number * imsize;
__shared__ T s_mem[2];
if (threadIdx.x == 0) {
s_mem[0] = s_mem[1] = 0;
}
__syncthreads();
paddle::platform::CudaAtomicAdd(&s_mem[0], x_mean);
paddle::platform::CudaAtomicAdd(&s_mem[1], x_var);
__syncthreads();
if (threadIdx.x == 0) {
paddle::platform::CudaAtomicAdd(&mean[bid * groups + gid], s_mem[0]);
paddle::platform::CudaAtomicAdd(&var[bid * groups + gid], s_mem[1]);
}
CudaAtomicAddWithWarp(&mean[bid * groups + gid], x_mean);
CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var);
}
template <typename T>
template <typename T, int flags>
__global__ void GroupNormForward(const T* x, const T* mean, const T* var,
const T* scale, const T* bias, int N, int C,
int imsize, int groups, int group_size,
......@@ -68,8 +84,8 @@ __global__ void GroupNormForward(const T* x, const T* mean, const T* var,
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T val = x[(bid * C + ccid) * imsize + imid];
val = (val - x_mean) * var_inv;
if (scale) val *= scale[gid * group_size + cid];
if (bias) val += bias[gid * group_size + cid];
if (flags & kHasScale) val *= scale[gid * group_size + cid];
if (flags & kHasBias) val += bias[gid * group_size + cid];
y[(bid * C + ccid) * imsize + imid] = val;
}
}
......@@ -115,93 +131,87 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
if (bias) bias_data = bias->data<T>();
int imsize = x_dims[2] * x_dims[3];
int block_size = std::min(512, imsize);
int block_size = std::min(1024, imsize);
dim3 grid(group_size, groups, x_dims[0]);
dim3 threads(block_size, 1, 1);
GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
x_data, x_dims[0], x_dims[1], imsize, groups, group_size, mean_data,
temp_var_data);
GroupNormForward<T><<<grid, threads, 0, dev_ctx.stream()>>>(
x_data, mean_data, temp_var_data, scale_data, bias_data, x_dims[0],
x_dims[1], imsize, groups, group_size, epsilon, y_data, var_data);
int flags =
(scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
UNROLL_ALL_CASES(flags, GroupNormForward, x_data, mean_data, temp_var_data,
scale_data, bias_data, x_dims[0], x_dims[1], imsize,
groups, group_size, epsilon, y_data, var_data);
}
};
template <typename T>
__global__ void GroupNormBackwardGetMeanAndVar(
const T* x, const T* mean, const T* var, const T* scale, const T* d_y,
int N, int C, int imsize, int groups, int group_size, T epsilon, T* d_x,
T* d_mean, T* d_var, T* d_scale, T* d_bias) {
template <typename T, int flags>
__global__ void GroupNormBackwardGetMeanAndVar(const T* x, const T* scale,
const T* bias, const T* d_y,
int N, int C, int imsize,
int groups, int group_size,
T epsilon, T* d_mean, T* d_var,
T* d_scale, T* d_bias) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int number = min(group_size, static_cast<int>(C - gid * group_size));
int ccid = gid * group_size + cid;
if (ccid >= C) return;
T x_mean = mean[bid * groups + gid];
T x_var = var[bid * groups + gid];
T var_inv = 1.0 / sqrt(x_var + epsilon);
T d_var_inv = 0, d_x_mean = 0;
T x_scale = (flags & kHasScale) ? scale[ccid] : 1;
T x_bias = (flags & kHasBias) ? bias[ccid] : 0;
T x_scale_inv = 0;
if (x_scale != 0) x_scale_inv = 1.0 / x_scale;
T d_mean_data = 0, d_var_data = 0, d_scale_data = 0, d_bias_data = 0;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T tmp = x[(bid * C + ccid) * imsize + imid];
T val = (tmp - x_mean) * var_inv;
T val = x[(bid * C + ccid) * imsize + imid] - x_bias;
T dval = d_y[(bid * C + ccid) * imsize + imid];
if (d_bias) d_bias_data += dval;
if (d_scale) d_scale_data += val * dval;
if (scale) dval = dval * scale[ccid];
d_var_data += (tmp - x_mean) * dval;
T d_tmp = dval * var_inv;
if (d_x) d_x[(bid * C + ccid) * imsize + imid] = d_tmp;
d_mean_data -= d_tmp;
}
__shared__ T s_mem[4];
if (threadIdx.x == 0) {
s_mem[0] = s_mem[1] = 0;
if (d_scale) s_mem[2] = 0;
if (d_bias) s_mem[3] = 0;
}
__syncthreads();
paddle::platform::CudaAtomicAdd(&s_mem[0], d_mean_data);
paddle::platform::CudaAtomicAdd(&s_mem[1], d_var_data);
if (d_scale) paddle::platform::CudaAtomicAdd(&s_mem[2], d_scale_data);
if (d_bias) paddle::platform::CudaAtomicAdd(&s_mem[3], d_bias_data);
__syncthreads();
if (threadIdx.x == 0) {
paddle::platform::CudaAtomicAdd(&d_mean[bid * groups + gid], s_mem[0]);
paddle::platform::CudaAtomicAdd(&d_var[bid * groups + gid], s_mem[1]);
if (d_scale) paddle::platform::CudaAtomicAdd(&d_scale[ccid], s_mem[2]);
if (d_bias) paddle::platform::CudaAtomicAdd(&d_bias[ccid], s_mem[3]);
d_var_data += val * dval;
d_mean_data += dval * x_scale;
val = val * x_scale_inv;
d_bias_data += dval;
d_scale_data += val * dval;
}
CudaAtomicAddWithWarp(&d_mean[bid * groups + gid], d_mean_data);
CudaAtomicAddWithWarp(&d_var[bid * groups + gid], d_var_data);
if (flags & kHasScale) CudaAtomicAddWithWarp(&d_scale[ccid], d_scale_data);
if (flags & kHasBias) CudaAtomicAddWithWarp(&d_bias[ccid], d_bias_data);
}
template <typename T>
__global__ void GroupNormBackward(const T* x, const T* mean, const T* var,
const T* d_mean, const T* d_var, int N, int C,
int imsize, int groups, int group_size,
T epsilon, T* d_x) {
template <typename T, int flags>
__global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale,
const T* bias, const T* var, const T* d_mean,
const T* d_var, int N, int C, int imsize,
int groups, int group_size, T epsilon,
T* d_x) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int number = min(group_size, static_cast<int>(C - gid * group_size));
int ccid = gid * group_size + cid;
if (ccid >= C) return;
T x_mean = mean[bid * groups + gid];
T x_var = var[bid * groups + gid];
T d_x_mean = d_mean[bid * groups + gid];
T d_var_inv = d_var[bid * groups + gid];
T d_x_var = d_var[bid * groups + gid];
T x_var_inv = 1.0 / sqrt(x_var + epsilon);
T number_inv = 1.0 / (number * imsize);
T x_scale = (flags & kHasScale) ? scale[ccid] : 1;
T x_bias = (flags & kHasBias) ? bias[ccid] : 0;
T x_scale_inv = 0;
if (x_scale != 0) x_scale_inv = 1.0 / x_scale;
T d_x_var =
-1.0 / (2 * (x_var + epsilon) * sqrt(x_var + epsilon)) * d_var_inv;
d_x_mean -= 2 * d_x_var * x_mean;
d_x_var /= number * imsize;
d_x_mean /= number * imsize;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T tmp = x[(bid * C + ccid) * imsize + imid];
if (d_x)
d_x[(bid * C + ccid) * imsize + imid] += d_x_mean + tmp * 2 * d_x_var;
T v_y = (tmp - x_bias) * x_scale_inv;
T dly = d_y[(bid * C + ccid) * imsize + imid];
d_x[(bid * C + ccid) * imsize + imid] =
x_var_inv *
(dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean);
}
}
......@@ -211,10 +221,10 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto* x = ctx.Input<Tensor>("X");
auto* mean = ctx.Input<Tensor>("Mean");
auto* x = ctx.Input<Tensor>("Y");
auto* var = ctx.Input<Tensor>("Variance");
auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto groups = ctx.Attr<int>("groups");
......@@ -226,11 +236,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
const auto& x_dims = x->dims();
const int group_size = (x_dims[1] - 1) / groups + 1;
T* d_x_data = nullptr;
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
d_x_data = d_x->data<T>();
}
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
......@@ -245,8 +251,9 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
T* temp_mean_data = temp_mean.data<T>();
auto* x_data = x->data<T>();
T* d_x_data = nullptr;
if (d_x) d_x_data = d_x->data<T>();
auto* y_data = d_y->data<T>();
auto* mean_data = mean->data<T>();
auto* var_data = var->data<T>();
T* d_scale_data = nullptr;
if (d_scale) {
......@@ -263,18 +270,25 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
const T* scale_data = nullptr;
if (scale) scale_data = scale->data<T>();
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();
int imsize = x_dims[2] * x_dims[3];
int block_size = std::min(512, imsize);
int block_size = std::min(1024, imsize);
dim3 grid(group_size, groups, x_dims[0]);
dim3 threads(block_size, 1, 1);
GroupNormBackwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
x_data, mean_data, var_data, scale_data, y_data, x_dims[0], x_dims[1],
imsize, groups, group_size, epsilon, d_x_data, temp_mean_data,
temp_var_data, d_scale_data, d_bias_data);
GroupNormBackward<T><<<grid, threads, 0, dev_ctx.stream()>>>(
x_data, mean_data, var_data, temp_mean_data, temp_var_data, x_dims[0],
x_dims[1], imsize, groups, group_size, epsilon, d_x_data);
int flags =
(scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
UNROLL_ALL_CASES(flags, GroupNormBackwardGetMeanAndVar, x_data, scale_data,
bias_data, y_data, x_dims[0], x_dims[1], imsize, groups,
group_size, epsilon, temp_mean_data, temp_var_data,
d_scale_data, d_bias_data);
if (d_x_data != nullptr) {
UNROLL_ALL_CASES(flags, GroupNormBackward, x_data, y_data, scale_data,
bias_data, var_data, temp_mean_data, temp_var_data,
x_dims[0], x_dims[1], imsize, groups, group_size,
epsilon, d_x_data);
}
}
};
......
......@@ -96,10 +96,10 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
auto* x = ctx.Input<Tensor>("X");
auto* mean = ctx.Input<Tensor>("Mean");
auto* x = ctx.Input<Tensor>("Y");
auto* var = ctx.Input<Tensor>("Variance");
auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto groups = ctx.Attr<int>("groups");
......@@ -111,19 +111,13 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
const auto& x_dims = x->dims();
const int group_size = (x_dims[1] - 1) / groups + 1;
// TODO(liangdun): need to check d_x is null
d_x->mutable_data<T>(ctx.GetPlace());
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
T* d_x_data = nullptr;
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, d_x, static_cast<T>(0));
d_x_data = d_x->data<T>();
}
auto* x_data = x->data<T>();
auto* d_x_data = d_x->data<T>();
auto* y_data = d_y->data<T>();
auto* mean_data = mean->data<T>();
auto* var_data = var->data<T>();
T* d_scale_data = nullptr;
if (d_scale) {
......@@ -140,6 +134,8 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
const T* scale_data = nullptr;
if (scale) scale_data = scale->data<T>();
const T* bias_data = nullptr;
if (bias) bias_data = bias->data<T>();
int imsize = x_dims[2] * x_dims[3];
auto* iter_x_data = x_data;
......@@ -147,46 +143,45 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
auto* iter_y_data = y_data;
for (int bid = 0; bid < x_dims[0]; bid++)
for (int gid = 0; gid < groups; gid++) {
T x_mean = mean_data[bid * groups + gid];
T x_var = var_data[bid * groups + gid];
T var_inv = 1.0 / sqrt(x_var + epsilon);
int number = std::min(group_size,
static_cast<int>(x_dims[1] - gid * group_size));
auto* tmp = iter_x_data;
auto* tmp2 = iter_d_x_data;
T d_var_inv = 0, d_x_mean = 0;
T number_inv = 1.0 / (number * imsize);
auto* iter_x_data2 = iter_x_data;
auto* iter_y_data2 = iter_y_data;
T dp_scale = 0, dp_bias = 0;
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize;
imid++, tmp++, iter_y_data++, iter_d_x_data++) {
T val = (tmp[0] - x_mean) * var_inv;
imid++, iter_x_data++, iter_y_data++) {
T val = iter_x_data[0];
if (bias_data) val -= bias_data[gid * group_size + cid];
T dval = iter_y_data[0];
dp_scale += val * dval;
dp_bias += dval * scale_data[gid * group_size + cid];
if (scale_data && scale_data[gid * group_size + cid] != 0)
val /= scale_data[gid * group_size + cid];
if (d_bias_data) d_bias_data[gid * group_size + cid] += dval;
if (d_scale_data)
d_scale_data[gid * group_size + cid] += val * dval;
if (scale_data) dval = scale_data[gid * group_size + cid] * dval;
d_var_inv += (tmp[0] - x_mean) * dval;
T d_tmp = dval * var_inv;
if (d_x_data) iter_d_x_data[0] += d_tmp;
d_x_mean -= d_tmp;
}
}
T d_x_var =
-1.0 / (2 * (x_var + epsilon) * sqrt(x_var + epsilon)) * d_var_inv;
d_x_mean -= 2 * d_x_var * x_mean;
d_x_var /= number * imsize;
d_x_mean /= number * imsize;
iter_d_x_data = tmp2;
if (d_x_data) {
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize;
imid++, iter_x_data++, iter_d_x_data++) {
iter_d_x_data[0] += d_x_mean;
iter_d_x_data[0] += iter_x_data[0] * 2 * d_x_var;
}
imid++, iter_d_x_data++, iter_x_data2++, iter_y_data2++) {
T v_y = iter_x_data2[0];
T dly = iter_y_data2[0];
T dss = dp_scale;
T dbs = dp_bias;
T v_scale = scale_data[gid * group_size + cid];
T v_bias = bias_data[gid * group_size + cid];
v_y -= v_bias;
if (v_scale != 0) v_y /= v_scale;
iter_d_x_data[0] =
(dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
var_inv;
}
}
}
......
......@@ -26,7 +26,8 @@ namespace math {
* \brief Compute the depthwise convolution which include
* forward process and backpropagation process
*/
template <typename DeviceContext, typename T>
template <typename DeviceContext, typename T,
bool fuse_relu_before_conv = false>
class DepthwiseConvFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
......@@ -36,7 +37,8 @@ class DepthwiseConvFunctor {
const std::vector<int>& dilations, framework::Tensor* output);
};
template <typename DeviceContext, typename T>
template <typename DeviceContext, typename T,
bool fuse_relu_before_conv = false>
class DepthwiseConvInputGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
......@@ -48,7 +50,8 @@ class DepthwiseConvInputGradFunctor {
framework::Tensor* input_grad);
};
template <typename DeviceContext, typename T>
template <typename DeviceContext, typename T,
bool fuse_relu_before_conv = false>
class DepthwiseConvFilterGradFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor& input,
......
......@@ -1023,6 +1023,20 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
to fuse elementwise_add_op and activation_op,
it may make the execution faster. Default False)DOC")
.def_property(
"fuse_relu_depthwise_conv",
[](const BuildStrategy &self) {
return self.fuse_relu_depthwise_conv_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finlaized.");
self.fuse_relu_depthwise_conv_ = b;
},
R"DOC(The type is BOOL, fuse_relu_depthwise_conv indicate whether
to fuse relu and depthwise_conv2d,
it will save GPU memory and may make the execution faster.
This options is only available in GPU devices.
Default False)DOC")
.def_property(
"memory_optimize",
[](const BuildStrategy &self) { return self.memory_optimize_; },
......
......@@ -76,7 +76,7 @@ def memory_usage(program, batch_size):
# Get the var_name list of first block and calculate
total_memory = 0.0
processed_var_names = set()
processed_var_names = set(["@EMPTY@"])
for op in program.global_block().ops:
for var_name in op.output_arg_names:
if var_name in processed_var_names:
......
......@@ -1972,6 +1972,7 @@ def conv2d(input,
'groups': groups,
'use_cudnn': use_cudnn,
'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False
})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
......
......@@ -42,6 +42,7 @@ class TestParallelExecutorBase(unittest.TestCase):
use_reduce=False,
use_ir_memory_optimize=False,
fuse_elewise_add_act_ops=False,
fuse_relu_depthwise_conv=False,
optimizer=fluid.optimizer.Adam,
use_fast_executor=False,
enable_sequential_execution=False):
......@@ -60,6 +61,7 @@ class TestParallelExecutorBase(unittest.TestCase):
loss = method(use_feed=feed_dict is not None)
if optimizer:
optimizer().minimize(loss)
if memory_opt:
......@@ -76,6 +78,7 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
build_strategy.memory_optimize = use_ir_memory_optimize
build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda():
......
......@@ -70,6 +70,7 @@ class TestConv2dOp(OpTest):
self.exhaustive_search = False
self.use_cuda = False
self.use_mkldnn = False
self.fuse_relu_before_depthwise_conv = False
self.data_format = "AnyLayout"
self.dtype = np.float32
self.init_kernel_type()
......@@ -84,8 +85,17 @@ class TestConv2dOp(OpTest):
}
input = np.random.random(self.input_size).astype(self.dtype)
if not self.testcuda():
self.fuse_relu_before_depthwise_conv = False
if self.fuse_relu_before_depthwise_conv:
input = input - 0.5
input -= (input < 0) * 0.1
input += (input >= 0) * 0.1
input2 = np.maximum(input, 0.0)
else:
input2 = input
filter = np.random.random(self.filter_size).astype(self.dtype)
output, _, _, _, _ = conv2d_forward_naive(input, filter, self.groups,
output, _, _, _, _ = conv2d_forward_naive(input2, filter, self.groups,
conv2d_param)
output = output.astype(self.dtype)
......@@ -101,6 +111,8 @@ class TestConv2dOp(OpTest):
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format,
'fuse_relu_before_depthwise_conv':
self.fuse_relu_before_depthwise_conv,
'exhaustive_search': self.exhaustive_search
}
self.outputs = {'Output': output}
......@@ -364,6 +376,78 @@ class TestDepthwiseConvWithDilation2(TestConv2dOp):
self.op_type = "depthwise_conv2d"
class TestDepthwiseConvandFuse(TestConv2dOp):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestDepthwiseConv2andFuse(TestConv2dOp):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestDepthwiseConv3andFuse(TestConv2dOp):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestDepthwiseConvWithDilationandFuse(TestConv2dOp):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.dilations = [2, 2]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestDepthwiseConvWithDilation2andFuse(TestConv2dOp):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.dilations = [2, 2]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestCUDNNExhaustiveSearch(TestConv2dOp):
def init_kernel_type(self):
self.use_cudnn = True
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parallel_executor_test_base import TestParallelExecutorBase
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import paddle
import paddle.dataset.mnist as mnist
import unittest
import os
MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio"
def norm(*args, **kargs):
return fluid.layers.batch_norm(*args, **kargs)
def sep_conv(input, channel, stride, filter, dilation=1, act=None):
# with scope('depthwise'):
input = fluid.layers.conv2d(
input,
input.shape[1],
filter,
stride,
groups=input.shape[1],
padding=(filter // 2) * dilation,
dilation=dilation,
use_cudnn=False,
bias_attr=False)
input = norm(input)
if act: input = act(input)
# with scope('pointwise'):
input = fluid.layers.conv2d(
input, channel, 1, 1, groups=1, padding=0, bias_attr=False)
input = norm(input)
if act: input = act(input)
return input
def simple_depthwise_net(use_feed):
if use_feed:
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else:
reader = fluid.layers.open_files(
filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
hidden = fluid.layers.reshape(img, (-1, 1, 28, 28))
for _ in range(4):
hidden = sep_conv(hidden, channel=200, stride=2, filter=5)
hidden = fluid.layers.relu(hidden)
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestMNIST(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=4)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
MNIST_RECORDIO_FILE, reader, feeder)
def _init_data(self, random=True):
np.random.seed(5)
if random:
img = np.random.random(size=[32, 784]).astype(np.float32)
else:
img = np.ones(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64')
return img, label
def _compare(self, model, use_cuda, random_data=True, only_forward=False):
if use_cuda and not core.is_compiled_with_cuda():
return
img, label = self._init_data(random_data)
def _optimizer(learning_rate=1e-6):
optimizer = fluid.optimizer.SGD(
learning_rate=learning_rate,
regularization=fluid.regularizer.L2Decay(1e-6))
return optimizer
if only_forward:
_optimizer = None
fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
fuse_relu_depthwise_conv=True,
use_ir_memory_optimize=True,
memory_opt=False,
optimizer=_optimizer)
not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
fuse_relu_depthwise_conv=False,
memory_opt=False,
optimizer=_optimizer)
for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
def test_simple_depthwise_with_fuse_op(self):
self._compare(simple_depthwise_net, True)
self._compare(simple_depthwise_net, False)
def test_simple_depthwise_with_fuse_op_only_forward(self):
self._compare(simple_depthwise_net, True, only_forward=True)
self._compare(simple_depthwise_net, False, only_forward=True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册