Commit 9f8f0fc2 authored by D Dun, committed by chengduo

Memory optimization of depthwise conv op and group norm op (#15313)

* mem opt

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* test=develop

* refine code  test=develop

* refine code  test=develop

* refine code  test=develop

* refine code  test=develop

* refine with cub test=develop

* fix mkldnn test && remove comments && test=develop

* polish code && test=develop

* add only_forward test && test=develop
Parent 9252aa41
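
Note (editor's sketch, not part of the commit): the fusion added here is opt-in through the new fuse_relu_depthwise_conv_ flag on BuildStrategy (see build_strategy.h below). A minimal C++ sketch of turning it on; how the strategy is then handed to the executor is assumed context and not shown in this diff.

#include "paddle/fluid/framework/details/build_strategy.h"

// Build a strategy that requests the new ReLU + depthwise-conv fusion.
// The pass is appended by ParallelExecutorPassBuilder and, per Apply() below,
// silently skipped when CUDA is not available.
paddle::framework::details::BuildStrategy MakeFusedStrategy() {
  paddle::framework::details::BuildStrategy strategy;
  strategy.fuse_relu_depthwise_conv_ = true;
  return strategy;
}
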
@@ -94,4 +94,5 @@ cc_library(build_strategy SRCS build_strategy.cc DEPS
        graph_viz_pass multi_devices_graph_pass
        multi_devices_graph_print_pass multi_devices_graph_check_pass
        fuse_elewise_add_act_pass multi_batch_merge_pass
+       fuse_relu_depthwise_conv_pass
        memory_optimize_pass lock_free_optimize_pass)
@@ -55,6 +55,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
    }
    // Add op fusion.
+   if (strategy.fuse_relu_depthwise_conv_) {
+     AppendPass("fuse_relu_depthwise_conv_pass");
+   }
    if (strategy.fuse_elewise_add_act_ops_) {
      auto fuse_elewise_add_act_pass = AppendPass("fuse_elewise_add_act_pass");
      // Add a graph viz pass to record a graph.
@@ -210,6 +213,12 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
        pass->Set<const std::vector<OpDesc *>>(
            kAllOpDescs,
            new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
+     } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
+       if (!use_cuda) {
+         LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
+                         "GPU, skipped.";
+         continue;
+       }
      }
      graph = pass->Apply(std::move(graph));
    }
@@ -220,6 +229,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
}  // namespace framework
}  // namespace paddle
+USE_PASS(fuse_relu_depthwise_conv_pass);
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
......
@@ -74,6 +74,8 @@ struct BuildStrategy {
  bool fuse_elewise_add_act_ops_{false};
+ bool fuse_relu_depthwise_conv_{false};
  bool memory_optimize_{false};
  bool memory_early_delete_{false};
......
@@ -70,6 +70,7 @@ if(WITH_MKLDNN)
endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
+cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector )
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_relu_depthwise_conv_pass.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
graph = FuseReluDepthwiseConv(std::move(graph), true);
graph = FuseReluDepthwiseConv(std::move(graph), false);
return graph;
}
std::unique_ptr<ir::Graph> FuseReluDepthwiseConvPass::FuseReluDepthwiseConv(
std::unique_ptr<ir::Graph> graph, bool only_forward) const {
PADDLE_ENFORCE(graph.get());
if (only_forward)
FusePassBase::Init("relu_depthwise_conv_only_forward", graph.get());
else
FusePassBase::Init("relu_depthwise_conv", graph.get());
/*
x ---act--> y ---layer-> z
+----------+
↓ ↓
x' <--act'--- y' <-layer'--- z'
fuse to:
x ---act-layer-> z
|
x' <--act-layer'--- z'
*/
GraphPatternDetector gpd;
auto *pattern = gpd.mutable_pattern();
std::string act_type = "relu";
std::string layer_type = "depthwise_conv2d";
auto *x = pattern->NewNode("x")->AsInput();
auto *y = pattern->NewNode("y")->AsIntermediate();
auto *z = pattern->NewNode("z")->AsOutput();
PDNode *xg = nullptr;
PDNode *yg = nullptr;
PDNode *zg = nullptr;
if (!only_forward) {
xg = pattern->NewNode("xg")->AsOutput();
yg = pattern->NewNode("yg")->AsIntermediate();
zg = pattern->NewNode("zg")->AsInput();
}
PDNode *act_g = nullptr;
PDNode *layer_g = nullptr;
auto *act = pattern->NewNode("act")->assert_is_op(act_type);
auto *layer = pattern->NewNode("layer")->assert_is_op(layer_type);
if (!only_forward) {
act_g = pattern->NewNode("act_g")->assert_is_op(act_type + "_grad");
layer_g = pattern->NewNode("layer_g")->assert_is_op(layer_type + "_grad");
}
act->LinksFrom({x}).LinksTo({y});
layer->LinksFrom({y}).LinksTo({z});
if (!only_forward) {
layer_g->LinksFrom({y, zg}).LinksTo({yg});
act_g->LinksFrom({y, yg}).LinksTo({xg});
}
int count = 0;
std::unordered_set<const Node *> need_removed_nodes;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseReluDepthwiseConv fuse";
// 1. turn on fuse option
auto *layer_op = subgraph.at(layer)->Op();
layer_op->SetAttr("use_cudnn", false);
layer_op->SetAttr("fuse_relu_before_depthwise_conv", true);
OpDesc *layer_g_op = nullptr;
if (!only_forward) {
layer_g_op = subgraph.at(layer_g)->Op();
layer_g_op->SetAttr("use_cudnn", false);
layer_g_op->SetAttr("fuse_relu_before_depthwise_conv", true);
}
// 2. connect x to layer and layer_g, layer_g to xg
auto *y_var = subgraph.at(y)->Var();
auto *x_var = subgraph.at(x)->Var();
VarDesc *yg_var = nullptr;
VarDesc *xg_var = nullptr;
if (!only_forward) {
yg_var = subgraph.at(yg)->Var();
xg_var = subgraph.at(xg)->Var();
}
PADDLE_ENFORCE_EQ(layer_op->Input("Input").size(), 1);
PADDLE_ENFORCE_EQ(layer_op->Input("Input")[0], y_var->Name());
layer_op->SetInput("Input", {x_var->Name()});
subgraph.at(layer)->inputs.push_back(subgraph.at(x));
subgraph.at(x)->outputs.push_back(subgraph.at(layer));
VLOG(4) << "replace " << y_var->Name() << " -> " << x_var->Name();
if (!only_forward) {
PADDLE_ENFORCE_EQ(layer_g_op->Input("Input").size(), 1);
PADDLE_ENFORCE_EQ(layer_g_op->Input("Input")[0], y_var->Name());
layer_g_op->SetInput("Input", {x_var->Name()});
subgraph.at(layer_g)->inputs.push_back(subgraph.at(x));
subgraph.at(x)->outputs.push_back(subgraph.at(layer_g));
PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input")).size(), 1);
PADDLE_ENFORCE_EQ(layer_g_op->Output(GradVarName("Input"))[0],
yg_var->Name());
layer_g_op->SetOutput(GradVarName("Input"), {xg_var->Name()});
subgraph.at(layer_g)->outputs.push_back(subgraph.at(xg));
subgraph.at(xg)->inputs.push_back(subgraph.at(layer_g));
VLOG(4) << "replace " << yg_var->Name() << " -> " << xg_var->Name();
}
// 3. delete y, yg, act, act_g
if (only_forward) {
need_removed_nodes.insert({subgraph.at(y), subgraph.at(act)});
} else {
need_removed_nodes.insert({subgraph.at(y), subgraph.at(yg),
subgraph.at(act), subgraph.at(act_g)});
}
count++;
};
gpd(graph.get(), handler);
GraphSafeRemoveNodes(graph.get(), need_removed_nodes);
AddStatis(count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_relu_depthwise_conv_pass,
paddle::framework::ir::FuseReluDepthwiseConvPass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the relu and depthwise conv
*/
class FuseReluDepthwiseConvPass : public FusePassBase {
public:
virtual ~FuseReluDepthwiseConvPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
std::unique_ptr<ir::Graph> FuseReluDepthwiseConv(
std::unique_ptr<ir::Graph> graph, bool only_forward) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -143,7 +143,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    // Get unique name for storing MKLDNN primitives
    const std::string key = platform::ConvMKLDNNHandler::GetHash(
        src_tz, weights_tz, strides, paddings, dilations, groups,
-       ctx.op().Output("Output"));
+       ctx.op().Input("Input") + ctx.op().Input("Filter"));
    const std::string key_conv_pd = key + "@conv_pd";
    std::vector<primitive> pipeline;
@@ -371,7 +371,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    platform::ConvMKLDNNHandler::AppendKey(
        &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
        input->format(), fuse_relu, fuse_residual_conn,
-       ctx.op().Output("Output"));
+       ctx.op().Input("Input") + ctx.op().Input("Filter"));
    const std::string key_conv_pd = key + "@conv_pd";
    bool need_s8_to_u8 = false;
@@ -798,7 +798,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
    const Tensor* input = ctx.Input<Tensor>("Input");
    const Tensor* filter = ctx.Input<Tensor>("Filter");
-   const Tensor* output = ctx.Input<Tensor>("Output");
    const Tensor* output_grad =
        ctx.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
@@ -810,9 +809,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
    PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
                       filter->format() != memory::format::format_undef,
                   "Wrong layout/format set for Filter tensor");
-   PADDLE_ENFORCE(output->layout() == DataLayout::kMKLDNN &&
-                      output->format() != memory::format::format_undef,
-                  "Wrong layout/format set for Output tensor");
    PADDLE_ENFORCE(output_grad->layout() == DataLayout::kMKLDNN &&
                       output_grad->format() != memory::format::format_undef,
                   "Wrong layout/format set for output_grad tensor");
@@ -840,18 +836,19 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
        paddle::framework::vectorize2int(filter->dims());
    int g = std::max(groups, 1);
    GetWeightsTz(weights_tz, g, is_conv3d);
-   std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+   std::vector<int> dst_tz =
+       paddle::framework::vectorize2int(output_grad->dims());
    auto src_format = input->format();
    mkldnn::memory::format weights_format =
        GetWeightsFormat(filter->format(), g, is_conv3d);
-   // Get an unique name from "argument" name of "Output" variable
+   // Get an unique name from "argument" name of "input" and "Filter" variable
    // as well as attributes of primitive to be created
    // This name will be used as key when saving info into device context
    const std::string key = platform::ConvMKLDNNHandler::GetHash(
        src_tz, weights_tz, strides, paddings, dilations, groups,
-       ctx.op().Input("Output"));
+       ctx.op().Input("Input") + ctx.op().Input("Filter"));
    const std::string key_conv_pd = key + "@conv_pd";
    std::vector<primitive> pipeline;
......
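
The conv_mkldnn_op.cc changes above re-key the cached MKLDNN primitives on the Input and Filter variable names instead of the Output name, so the grad kernel no longer touches the forward Output tensor (whose memory can then be reused). A rough, hypothetical illustration of the idea; the real key is produced by ConvMKLDNNHandler::GetHash/AppendKey, whose internals are not shown here.

#include <string>
#include <vector>

// Hypothetical helper (not Paddle API): forward and backward can derive the
// same cache key from shapes plus the Input/Filter variable names alone.
std::string MakeConvCacheKey(const std::vector<int>& src_tz,
                             const std::string& input_name,
                             const std::string& filter_name) {
  std::string key;
  for (int d : src_tz) {
    key += std::to_string(d);
    key += "_";
  }
  return key + input_name + filter_name;
}
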
@@ -171,6 +171,9 @@ void Conv2DOpMaker::Make() {
      "use_cudnn",
      "(bool, default false) Only used in cudnn kernel, need install cudnn")
      .SetDefault(false);
+ AddAttr<bool>("fuse_relu_before_depthwise_conv",
+               "(bool, default false) Only used in cuda depthwise kernel")
+     .SetDefault(false);
  AddAttr<bool>("use_mkldnn",
                "(bool, default false) Only used in mkldnn kernel")
      .SetDefault(false);
@@ -412,18 +415,43 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
      customized_type_value);
}
+class Conv2dGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op = new framework::OpDesc();
+    op->SetType(GradOpType());
+    op->SetInput("Input", Input("Input"));
+    op->SetInput("Filter", Input("Filter"));
+    op->SetInput("Bias", Input("Bias"));
+    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
+    op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
+    op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
+    op->SetAttrMap(Attrs());
+    return std::unique_ptr<framework::OpDesc>(op);
+  }
+  virtual std::string GradOpType() const {
+    return this->ForwardOpType() + "_grad";
+  }
+};
}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(conv2d, ops::ConvOp, ops::Conv2DOpMaker,
-                 ops::ConvOpInferVarType,
-                 paddle::framework::DefaultGradOpDescMaker<true>);
+                 ops::ConvOpInferVarType, ops::Conv2dGradMaker);
REGISTER_OPERATOR(conv2d_grad, ops::ConvOpGrad);
// depthwise convolution op
REGISTER_OPERATOR(depthwise_conv2d, ops::ConvOp, ops::Conv2DOpMaker,
-                 paddle::framework::DefaultGradOpDescMaker<true>);
+                 ops::ConvOpInferVarType, ops::Conv2dGradMaker);
REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad);
REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
......
@@ -397,13 +397,19 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
+   bool fuse_relu = context.Attr<bool>("fuse_relu_before_depthwise_conv");
-   math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
    auto& dev_ctx = context.template device_context<DeviceContext>();
+   if (fuse_relu) {
+     math::DepthwiseConvFunctor<DeviceContext, T, true> depthwiseConv;
+     depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
+                   output);
+   } else {
+     math::DepthwiseConvFunctor<DeviceContext, T, false> depthwiseConv;
      depthwiseConv(dev_ctx, *input, filter, strides, paddings, dilations,
                    output);
+   }
  }
};
template <typename DeviceContext, typename T>
@@ -424,27 +430,42 @@ class DepthwiseConvGradKernel : public framework::OpKernel<T> {
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
+   bool fuse_relu = context.Attr<bool>("fuse_relu_before_depthwise_conv");
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
-   math::DepthwiseConvInputGradFunctor<DeviceContext, T>
-       depthwiseConvInputGrad;
-   math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
-       depthwiseConvFilterGrad;
    if (input_grad) {
      input_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, input_grad, static_cast<T>(0));
+     if (fuse_relu) {
+       math::DepthwiseConvInputGradFunctor<DeviceContext, T, true>
+           depthwiseConvInputGrad;
+       depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
+                              paddings, dilations, input_grad);
+     } else {
+       math::DepthwiseConvInputGradFunctor<DeviceContext, T, false>
+           depthwiseConvInputGrad;
        depthwiseConvInputGrad(dev_ctx, *input, filter, *output_grad, strides,
                               paddings, dilations, input_grad);
+     }
    }
    if (filter_grad) {
      filter_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));
-     depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides, paddings,
-                             dilations, filter_grad);
+     if (fuse_relu) {
+       math::DepthwiseConvFilterGradFunctor<DeviceContext, T, true>
+           depthwiseConvFilterGrad;
+       depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
+                               paddings, dilations, filter_grad);
+     } else {
+       math::DepthwiseConvFilterGradFunctor<DeviceContext, T, false>
+           depthwiseConvFilterGrad;
+       depthwiseConvFilterGrad(dev_ctx, *input, *output_grad, strides,
+                               paddings, dilations, filter_grad);
+     }
    }
  }
};
......
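
The conv_op.h changes above turn the runtime attribute fuse_relu_before_depthwise_conv into a compile-time bool template argument of the math functors, so the CUDA kernels can test fuse_relu_before_conv without a per-element runtime branch. A self-contained sketch of that dispatch pattern (names below are illustrative, not Paddle's):

#include <iostream>

// A functor templated on a compile-time flag, chosen from a runtime bool,
// mirroring how DepthwiseConvFunctor<DeviceContext, T, fuse_relu> is selected.
template <bool FuseRelu>
struct FusedOp {
  float operator()(float x) const {
    if (FuseRelu) x = x > 0.f ? x : 0.f;  // ReLU folded in, resolved at compile time
    return x * 2.f;                       // stand-in for the actual convolution
  }
};

float Run(float x, bool fuse_relu) {
  if (fuse_relu) return FusedOp<true>()(x);
  return FusedOp<false>()(x);
}

int main() { std::cout << Run(-1.f, true) << " " << Run(-1.f, false) << "\n"; }  // prints: 0 -2
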
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/group_norm_op.h"
+#include <string>
namespace paddle {
namespace operators {
@@ -102,8 +103,8 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
  void InferShape(framework::InferShapeContext *ctx) const override {
    // check input
-   PADDLE_ENFORCE(ctx->HasInput("X"),
-                  "Input(X) of GroupNormOp should not be null.");
+   PADDLE_ENFORCE(ctx->HasInput("Y"),
+                  "Input(Y) of GroupNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Mean"),
                   "Input(Mean) of GroupNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Variance"),
@@ -113,7 +114,7 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
    // check output
    if (ctx->HasOutput(framework::GradVarName("X"))) {
-     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
    }
    if (ctx->HasOutput(framework::GradVarName("Scale"))) {
      ctx->SetOutputDim(framework::GradVarName("Scale"),
@@ -145,12 +146,36 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
  }
};
+class GroupNormGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *op = new framework::OpDesc();
+    op->SetType("group_norm_grad");
+    op->SetInput("Scale", Input("Scale"));
+    op->SetInput("Bias", Input("Bias"));
+    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
+    op->SetInput("Y", Output("Y"));
+    op->SetInput("Mean", Output("Mean"));
+    op->SetInput("Variance", Output("Variance"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
+    op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale"));
+    op->SetAttrMap(Attrs());
+    return std::unique_ptr<framework::OpDesc>(op);
+  }
+};
}  // namespace operators
}  // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(group_norm, ops::GroupNormOp, ops::GroupNormOpMaker,
-                 paddle::framework::DefaultGradOpDescMaker<true>);
+                 ops::GroupNormGradMaker);
REGISTER_OPERATOR(group_norm_grad, ops::GroupNormGradOp);
REGISTER_OP_CPU_KERNEL(
    group_norm, ops::GroupNormKernel<paddle::platform::CPUDeviceContext, float>,
......
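
For reference, a sketch of why the gradient can now be computed from Y, Scale, Bias and Variance alone (my reading of the new kernels below, not text from the commit). With m = number * imsize elements per group:

\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} = \frac{y - \beta}{\gamma}, \qquad
\frac{\partial L}{\partial x} = \frac{1}{\sqrt{\sigma^2 + \epsilon}}
  \Big( \gamma\,dy - \frac{\hat{x}}{m} \sum dy\,(y - \beta) - \frac{1}{m} \sum \gamma\,dy \Big), \qquad
\frac{\partial L}{\partial \gamma} = \sum dy\,\hat{x}, \qquad
\frac{\partial L}{\partial \beta} = \sum dy .

The two sums are exactly the d_var (sum of dy*(y - beta)) and d_mean (sum of gamma*dy) accumulators the backward kernels build, so neither X nor the saved Mean is needed any more, which is where the memory saving comes from.
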
@@ -12,12 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#include <cub/cub.cuh>
+#include "cub/cub.cuh"
#include "paddle/fluid/operators/group_norm_op.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
namespace paddle {
namespace operators {
+enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
+#define CHECK_CASE(i, flags, kernel_name, args...)                   \
+  if (i == flags) {                                                  \
+    kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(args); \
+  }
+// 0 for no scale, no bias
+// 1 for has scale, no bias
+// 2 for no scale, has bias
+// 3 for has scale, has bias
+#define UNROLL_ALL_CASES(flags, kernel_name, args...) \
+  CHECK_CASE(0, flags, kernel_name, args)             \
+  CHECK_CASE(1, flags, kernel_name, args)             \
+  CHECK_CASE(2, flags, kernel_name, args)             \
+  CHECK_CASE(3, flags, kernel_name, args)
+template <typename T>
+__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
+  typedef cub::WarpReduce<T> WarpReduce;
+  typename WarpReduce::TempStorage temp_storage;
+  value = WarpReduce(temp_storage).Sum(value);
+  if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
+}
template <typename T>
__global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C,
                                              int imsize, int groups,
@@ -36,21 +62,11 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C,
  }
  x_mean /= number * imsize;
  x_var /= number * imsize;
- __shared__ T s_mem[2];
- if (threadIdx.x == 0) {
-   s_mem[0] = s_mem[1] = 0;
- }
- __syncthreads();
- paddle::platform::CudaAtomicAdd(&s_mem[0], x_mean);
- paddle::platform::CudaAtomicAdd(&s_mem[1], x_var);
- __syncthreads();
- if (threadIdx.x == 0) {
-   paddle::platform::CudaAtomicAdd(&mean[bid * groups + gid], s_mem[0]);
-   paddle::platform::CudaAtomicAdd(&var[bid * groups + gid], s_mem[1]);
- }
+ CudaAtomicAddWithWarp(&mean[bid * groups + gid], x_mean);
+ CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var);
}
-template <typename T>
+template <typename T, int flags>
__global__ void GroupNormForward(const T* x, const T* mean, const T* var,
                                 const T* scale, const T* bias, int N, int C,
                                 int imsize, int groups, int group_size,
@@ -68,8 +84,8 @@ __global__ void GroupNormForward(const T* x, const T* mean, const T* var,
  for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
    T val = x[(bid * C + ccid) * imsize + imid];
    val = (val - x_mean) * var_inv;
-   if (scale) val *= scale[gid * group_size + cid];
-   if (bias) val += bias[gid * group_size + cid];
+   if (flags & kHasScale) val *= scale[gid * group_size + cid];
+   if (flags & kHasBias) val += bias[gid * group_size + cid];
    y[(bid * C + ccid) * imsize + imid] = val;
  }
}
@@ -115,93 +131,87 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
    if (bias) bias_data = bias->data<T>();
    int imsize = x_dims[2] * x_dims[3];
-   int block_size = std::min(512, imsize);
+   int block_size = std::min(1024, imsize);
    dim3 grid(group_size, groups, x_dims[0]);
    dim3 threads(block_size, 1, 1);
    GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
        x_data, x_dims[0], x_dims[1], imsize, groups, group_size, mean_data,
        temp_var_data);
-   GroupNormForward<T><<<grid, threads, 0, dev_ctx.stream()>>>(
-       x_data, mean_data, temp_var_data, scale_data, bias_data, x_dims[0],
-       x_dims[1], imsize, groups, group_size, epsilon, y_data, var_data);
+   int flags =
+       (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
+   UNROLL_ALL_CASES(flags, GroupNormForward, x_data, mean_data, temp_var_data,
+                    scale_data, bias_data, x_dims[0], x_dims[1], imsize,
+                    groups, group_size, epsilon, y_data, var_data);
  }
};
-template <typename T>
-__global__ void GroupNormBackwardGetMeanAndVar(
-    const T* x, const T* mean, const T* var, const T* scale, const T* d_y,
-    int N, int C, int imsize, int groups, int group_size, T epsilon, T* d_x,
-    T* d_mean, T* d_var, T* d_scale, T* d_bias) {
+template <typename T, int flags>
+__global__ void GroupNormBackwardGetMeanAndVar(const T* x, const T* scale,
+                                               const T* bias, const T* d_y,
+                                               int N, int C, int imsize,
+                                               int groups, int group_size,
+                                               T epsilon, T* d_mean, T* d_var,
+                                               T* d_scale, T* d_bias) {
  int gid = blockIdx.y;
  int cid = blockIdx.x;
  int bid = blockIdx.z;
  int number = min(group_size, static_cast<int>(C - gid * group_size));
  int ccid = gid * group_size + cid;
  if (ccid >= C) return;
- T x_mean = mean[bid * groups + gid];
- T x_var = var[bid * groups + gid];
- T var_inv = 1.0 / sqrt(x_var + epsilon);
- T d_var_inv = 0, d_x_mean = 0;
+ T x_scale = (flags & kHasScale) ? scale[ccid] : 1;
+ T x_bias = (flags & kHasBias) ? bias[ccid] : 0;
+ T x_scale_inv = 0;
+ if (x_scale != 0) x_scale_inv = 1.0 / x_scale;
  T d_mean_data = 0, d_var_data = 0, d_scale_data = 0, d_bias_data = 0;
  for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
-   T tmp = x[(bid * C + ccid) * imsize + imid];
-   T val = (tmp - x_mean) * var_inv;
+   T val = x[(bid * C + ccid) * imsize + imid] - x_bias;
    T dval = d_y[(bid * C + ccid) * imsize + imid];
-   if (d_bias) d_bias_data += dval;
-   if (d_scale) d_scale_data += val * dval;
-   if (scale) dval = dval * scale[ccid];
-   d_var_data += (tmp - x_mean) * dval;
-   T d_tmp = dval * var_inv;
-   if (d_x) d_x[(bid * C + ccid) * imsize + imid] = d_tmp;
-   d_mean_data -= d_tmp;
- }
- __shared__ T s_mem[4];
- if (threadIdx.x == 0) {
-   s_mem[0] = s_mem[1] = 0;
-   if (d_scale) s_mem[2] = 0;
-   if (d_bias) s_mem[3] = 0;
- }
- __syncthreads();
- paddle::platform::CudaAtomicAdd(&s_mem[0], d_mean_data);
- paddle::platform::CudaAtomicAdd(&s_mem[1], d_var_data);
- if (d_scale) paddle::platform::CudaAtomicAdd(&s_mem[2], d_scale_data);
- if (d_bias) paddle::platform::CudaAtomicAdd(&s_mem[3], d_bias_data);
- __syncthreads();
- if (threadIdx.x == 0) {
-   paddle::platform::CudaAtomicAdd(&d_mean[bid * groups + gid], s_mem[0]);
-   paddle::platform::CudaAtomicAdd(&d_var[bid * groups + gid], s_mem[1]);
-   if (d_scale) paddle::platform::CudaAtomicAdd(&d_scale[ccid], s_mem[2]);
-   if (d_bias) paddle::platform::CudaAtomicAdd(&d_bias[ccid], s_mem[3]);
+   d_var_data += val * dval;
+   d_mean_data += dval * x_scale;
+   val = val * x_scale_inv;
+   d_bias_data += dval;
+   d_scale_data += val * dval;
  }
+ CudaAtomicAddWithWarp(&d_mean[bid * groups + gid], d_mean_data);
+ CudaAtomicAddWithWarp(&d_var[bid * groups + gid], d_var_data);
+ if (flags & kHasScale) CudaAtomicAddWithWarp(&d_scale[ccid], d_scale_data);
+ if (flags & kHasBias) CudaAtomicAddWithWarp(&d_bias[ccid], d_bias_data);
}
-template <typename T>
-__global__ void GroupNormBackward(const T* x, const T* mean, const T* var,
-                                  const T* d_mean, const T* d_var, int N, int C,
-                                  int imsize, int groups, int group_size,
-                                  T epsilon, T* d_x) {
+template <typename T, int flags>
+__global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale,
+                                  const T* bias, const T* var, const T* d_mean,
+                                  const T* d_var, int N, int C, int imsize,
+                                  int groups, int group_size, T epsilon,
+                                  T* d_x) {
  int gid = blockIdx.y;
  int cid = blockIdx.x;
  int bid = blockIdx.z;
  int number = min(group_size, static_cast<int>(C - gid * group_size));
  int ccid = gid * group_size + cid;
  if (ccid >= C) return;
- T x_mean = mean[bid * groups + gid];
  T x_var = var[bid * groups + gid];
  T d_x_mean = d_mean[bid * groups + gid];
- T d_var_inv = d_var[bid * groups + gid];
+ T d_x_var = d_var[bid * groups + gid];
+ T x_var_inv = 1.0 / sqrt(x_var + epsilon);
+ T number_inv = 1.0 / (number * imsize);
+ T x_scale = (flags & kHasScale) ? scale[ccid] : 1;
+ T x_bias = (flags & kHasBias) ? bias[ccid] : 0;
+ T x_scale_inv = 0;
+ if (x_scale != 0) x_scale_inv = 1.0 / x_scale;
- T d_x_var =
-     -1.0 / (2 * (x_var + epsilon) * sqrt(x_var + epsilon)) * d_var_inv;
- d_x_mean -= 2 * d_x_var * x_mean;
- d_x_var /= number * imsize;
- d_x_mean /= number * imsize;
  for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
    T tmp = x[(bid * C + ccid) * imsize + imid];
-   if (d_x)
-     d_x[(bid * C + ccid) * imsize + imid] += d_x_mean + tmp * 2 * d_x_var;
+   T v_y = (tmp - x_bias) * x_scale_inv;
+   T dly = d_y[(bid * C + ccid) * imsize + imid];
+   d_x[(bid * C + ccid) * imsize + imid] =
+       x_var_inv *
+       (dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean);
  }
}
@@ -211,10 +221,10 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
-   auto* x = ctx.Input<Tensor>("X");
-   auto* mean = ctx.Input<Tensor>("Mean");
+   auto* x = ctx.Input<Tensor>("Y");
    auto* var = ctx.Input<Tensor>("Variance");
    auto* scale = ctx.Input<Tensor>("Scale");
+   auto* bias = ctx.Input<Tensor>("Bias");
    auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto groups = ctx.Attr<int>("groups");
@@ -226,11 +236,7 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
    const auto& x_dims = x->dims();
    const int group_size = (x_dims[1] - 1) / groups + 1;
-   T* d_x_data = nullptr;
-   if (d_x) {
      d_x->mutable_data<T>(ctx.GetPlace());
-     d_x_data = d_x->data<T>();
-   }
    math::SetConstant<platform::CUDADeviceContext, T> set_zero;
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
@@ -245,8 +251,9 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
    T* temp_mean_data = temp_mean.data<T>();
    auto* x_data = x->data<T>();
+   T* d_x_data = nullptr;
+   if (d_x) d_x_data = d_x->data<T>();
    auto* y_data = d_y->data<T>();
-   auto* mean_data = mean->data<T>();
    auto* var_data = var->data<T>();
    T* d_scale_data = nullptr;
    if (d_scale) {
@@ -263,18 +270,25 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
    const T* scale_data = nullptr;
    if (scale) scale_data = scale->data<T>();
+   const T* bias_data = nullptr;
+   if (bias) bias_data = bias->data<T>();
    int imsize = x_dims[2] * x_dims[3];
-   int block_size = std::min(512, imsize);
+   int block_size = std::min(1024, imsize);
    dim3 grid(group_size, groups, x_dims[0]);
    dim3 threads(block_size, 1, 1);
-   GroupNormBackwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
-       x_data, mean_data, var_data, scale_data, y_data, x_dims[0], x_dims[1],
-       imsize, groups, group_size, epsilon, d_x_data, temp_mean_data,
-       temp_var_data, d_scale_data, d_bias_data);
-   GroupNormBackward<T><<<grid, threads, 0, dev_ctx.stream()>>>(
-       x_data, mean_data, var_data, temp_mean_data, temp_var_data, x_dims[0],
-       x_dims[1], imsize, groups, group_size, epsilon, d_x_data);
+   int flags =
+       (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
+   UNROLL_ALL_CASES(flags, GroupNormBackwardGetMeanAndVar, x_data, scale_data,
+                    bias_data, y_data, x_dims[0], x_dims[1], imsize, groups,
+                    group_size, epsilon, temp_mean_data, temp_var_data,
+                    d_scale_data, d_bias_data);
+   if (d_x_data != nullptr) {
+     UNROLL_ALL_CASES(flags, GroupNormBackward, x_data, y_data, scale_data,
+                      bias_data, var_data, temp_mean_data, temp_var_data,
+                      x_dims[0], x_dims[1], imsize, groups, group_size,
+                      epsilon, d_x_data);
+   }
  }
};
......
@@ -96,10 +96,10 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const float epsilon = ctx.Attr<float>("epsilon");
-   auto* x = ctx.Input<Tensor>("X");
-   auto* mean = ctx.Input<Tensor>("Mean");
+   auto* x = ctx.Input<Tensor>("Y");
    auto* var = ctx.Input<Tensor>("Variance");
    auto* scale = ctx.Input<Tensor>("Scale");
+   auto* bias = ctx.Input<Tensor>("Bias");
    auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto groups = ctx.Attr<int>("groups");
@@ -111,19 +111,13 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
    const auto& x_dims = x->dims();
    const int group_size = (x_dims[1] - 1) / groups + 1;
-   // TODO(liangdun): need to check d_x is null
+   d_x->mutable_data<T>(ctx.GetPlace());
    math::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-   T* d_x_data = nullptr;
-   if (d_x) {
-     d_x->mutable_data<T>(ctx.GetPlace());
-     set_zero(dev_ctx, d_x, static_cast<T>(0));
-     d_x_data = d_x->data<T>();
-   }
    auto* x_data = x->data<T>();
+   auto* d_x_data = d_x->data<T>();
    auto* y_data = d_y->data<T>();
-   auto* mean_data = mean->data<T>();
    auto* var_data = var->data<T>();
    T* d_scale_data = nullptr;
    if (d_scale) {
@@ -140,6 +134,8 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
    const T* scale_data = nullptr;
    if (scale) scale_data = scale->data<T>();
+   const T* bias_data = nullptr;
+   if (bias) bias_data = bias->data<T>();
    int imsize = x_dims[2] * x_dims[3];
    auto* iter_x_data = x_data;
@@ -147,46 +143,45 @@ class GroupNormGradKernel : public framework::OpKernel<T> {
    auto* iter_y_data = y_data;
    for (int bid = 0; bid < x_dims[0]; bid++)
      for (int gid = 0; gid < groups; gid++) {
-       T x_mean = mean_data[bid * groups + gid];
        T x_var = var_data[bid * groups + gid];
        T var_inv = 1.0 / sqrt(x_var + epsilon);
        int number = std::min(group_size,
                              static_cast<int>(x_dims[1] - gid * group_size));
-       auto* tmp = iter_x_data;
-       auto* tmp2 = iter_d_x_data;
-       T d_var_inv = 0, d_x_mean = 0;
+       T number_inv = 1.0 / (number * imsize);
+       auto* iter_x_data2 = iter_x_data;
+       auto* iter_y_data2 = iter_y_data;
+       T dp_scale = 0, dp_bias = 0;
        for (int cid = 0; cid < number; cid++) {
          for (int imid = 0; imid < imsize;
-              imid++, tmp++, iter_y_data++, iter_d_x_data++) {
-           T val = (tmp[0] - x_mean) * var_inv;
+              imid++, iter_x_data++, iter_y_data++) {
+           T val = iter_x_data[0];
+           if (bias_data) val -= bias_data[gid * group_size + cid];
            T dval = iter_y_data[0];
+           dp_scale += val * dval;
+           dp_bias += dval * scale_data[gid * group_size + cid];
+           if (scale_data && scale_data[gid * group_size + cid] != 0)
+             val /= scale_data[gid * group_size + cid];
            if (d_bias_data) d_bias_data[gid * group_size + cid] += dval;
            if (d_scale_data)
              d_scale_data[gid * group_size + cid] += val * dval;
-           if (scale_data) dval = scale_data[gid * group_size + cid] * dval;
-           d_var_inv += (tmp[0] - x_mean) * dval;
-           T d_tmp = dval * var_inv;
-           if (d_x_data) iter_d_x_data[0] += d_tmp;
-           d_x_mean -= d_tmp;
          }
        }
-       T d_x_var =
-           -1.0 / (2 * (x_var + epsilon) * sqrt(x_var + epsilon)) * d_var_inv;
-       d_x_mean -= 2 * d_x_var * x_mean;
-       d_x_var /= number * imsize;
-       d_x_mean /= number * imsize;
-       iter_d_x_data = tmp2;
-       if (d_x_data) {
        for (int cid = 0; cid < number; cid++) {
          for (int imid = 0; imid < imsize;
-              imid++, iter_x_data++, iter_d_x_data++) {
-           iter_d_x_data[0] += d_x_mean;
-           iter_d_x_data[0] += iter_x_data[0] * 2 * d_x_var;
-         }
+              imid++, iter_d_x_data++, iter_x_data2++, iter_y_data2++) {
+           T v_y = iter_x_data2[0];
+           T dly = iter_y_data2[0];
+           T dss = dp_scale;
+           T dbs = dp_bias;
+           T v_scale = scale_data[gid * group_size + cid];
+           T v_bias = bias_data[gid * group_size + cid];
+           v_y -= v_bias;
+           if (v_scale != 0) v_y /= v_scale;
+           iter_d_x_data[0] =
+               (dly * v_scale - number_inv * dss * v_y - number_inv * dbs) *
+               var_inv;
          }
        }
      }
......
@@ -14,7 +14,9 @@ limitations under the License. */
#include <algorithm>
#include <vector>
+#include "cub/cub.cuh"
#include "paddle/fluid/operators/math/depthwise_conv.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
@@ -22,28 +24,11 @@ namespace operators {
namespace math {
template <typename T>
-__inline__ __device__ T warpReduceSum(T val) {
-#if CUDA_VERSION < 9000
-  for (int offset = 16; offset > 0; offset /= 2)
-    val += __shfl_down(val, offset);
-  return val;
-#else
-#define FULL_MASK 0xffffffff
-  for (int offset = 16; offset > 0; offset /= 2)
-    val += __shfl_down_sync(FULL_MASK, val, offset);
-  return val;
-#endif
-}
-__forceinline__ __device__ unsigned lane_id() {
-  unsigned ret;
-  asm volatile("mov.u32 %0, %laneid;" : "=r"(ret));
-  return ret;
-}
-__forceinline__ __device__ unsigned warp_id() {
-  unsigned ret;
-  asm volatile("mov.u32 %0, %warpid;" : "=r"(ret));
-  return ret;
-}
+__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
+  typedef cub::WarpReduce<T> WarpReduce;
+  typename WarpReduce::TempStorage temp_storage;
+  value = WarpReduce(temp_storage).Sum(value);
+  if (cub::LaneId() == 0) platform::CudaAtomicAdd(sum, value);
+}
#define ARG_DEFINE_KernelDepthwiseConv \
@@ -58,7 +43,7 @@ __forceinline__ __device__ unsigned warp_id() {
// A Cuda kernel to compute the depthwise convolution forward pass
// in NCHW format.
-template <typename T>
+template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
@@ -87,8 +72,12 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
          if (h_in >= h_start && h_in < h_end && w_in >= w_start &&
              w_in < w_end) {
            const int offset = in_offset + h_in * input_width + w_in;
+           if (fuse_relu_before_conv) {
+             value += weight[weight_offset] * max(0.0f, input_data[offset]);
+           } else {
              value += weight[weight_offset] * input_data[offset];
            }
+         }
          weight_offset++;
        }
      }
@@ -100,7 +89,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
  }
}
-template <typename T, int c_filter>
+template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvCFilter(
    ARG_DEFINE_KernelDepthwiseConv) {
  const int kWeghtSize = c_filter * c_filter;
@@ -137,10 +126,15 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
        if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
            w_in < input_width) {
          const int offset = in_offset + h_in * input_width + w_in;
+         if (fuse_relu_before_conv) {
+           value += r_weight[h_f * c_filter + w_f] *
+                    max(0.0f, input_data[offset]);
+         } else {
            value += r_weight[h_f * c_filter + w_f] * input_data[offset];
          }
        }
      }
+     }
    int index =
        ((batch * gridDim.x + c_out) * output_height + h_out) * output_width +
        w_out;
@@ -149,18 +143,19 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
  }
}
-template <typename T, int c_filter_multiplier, int c_stride, int c_filter>
+template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
+          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  if (c_filter_multiplier == 0) {
    if (c_filter == -1)
-     KernelDepthwiseConv<T>(
+     KernelDepthwiseConv<T, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          filter_multiplier, filter_height, filter_width, stride_height,
          stride_width, padding_height, padding_width, dilate_height,
          dilate_width, output_data);
    else
-     KernelDepthwiseConvCFilter<T, c_filter>(
+     KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          filter_multiplier, filter_height, filter_width, stride_height,
@@ -168,14 +163,14 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
          dilate_width, output_data);
  } else {
    if (c_filter == -1)
-     KernelDepthwiseConv<T>(input_data, filter_data, batch_size,
-                            output_channels, output_height, output_width,
-                            input_channels, input_height, input_width,
-                            c_filter_multiplier, filter_height, filter_height,
-                            c_stride, c_stride, padding_height, padding_width,
-                            dilate_height, dilate_width, output_data);
+     KernelDepthwiseConv<T, fuse_relu_before_conv>(
+         input_data, filter_data, batch_size, output_channels, output_height,
+         output_width, input_channels, input_height, input_width,
+         c_filter_multiplier, filter_height, filter_height, c_stride, c_stride,
+         padding_height, padding_width, dilate_height, dilate_width,
+         output_data);
    else
-     KernelDepthwiseConvCFilter<T, c_filter>(
+     KernelDepthwiseConvCFilter<T, c_filter, fuse_relu_before_conv>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          c_filter_multiplier, filter_height, filter_height, c_stride, c_stride,
@@ -186,17 +181,18 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
// CUDA kernel to compute the depthwise convolution backprop w.r.t input.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad                                \
- const T *const output_grad_data, const T *const filter_data,                 \
-     const int batch_size, const int output_channels,                         \
-     const int output_height, const int output_width,                         \
-     const int input_channels, const int input_height, const int input_width, \
+ const T *const input_data, const T *const output_grad_data,                  \
+     const T *const filter_data, const int batch_size,                        \
+     const int output_channels, const int output_height,                      \
+     const int output_width, const int input_channels,                        \
+     const int input_height, const int input_width,                           \
      const int filter_multiplier, const int filter_height,                    \
      const int filter_width, const int stride_height, const int stride_width, \
      const int padding_height, const int padding_width,                       \
      const int dilate_height, const int dilate_width,                         \
      T *const input_grad_data
-template <typename T>
+template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGrad(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
@@ -217,6 +213,15 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
      int w_out_end = w_in + padding_width;
      T value = 0;
+     int index =
+         ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
+         w_in;
+     if (fuse_relu_before_conv) {
+       if (input_data[index] <= 0) {
+         input_grad_data[index] = 0;
+         continue;
+       }
+     }
      for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier;
           c_out++) {
@@ -242,15 +247,13 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
          }
        }
      }
-     int index =
-         ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
-         w_in;
      input_grad_data[index] = value;
    }
  }
}
-template <typename T, int c_filter, int c_filter_multiplier>
+template <typename T, int c_filter, int c_filter_multiplier,
+          bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  const int kWeghtSize = c_filter * c_filter * c_filter_multiplier + 1;
@@ -276,6 +279,15 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
      int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;
      T value = 0;
+     int index =
+         ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
+         w_in;
+     if (fuse_relu_before_conv) {
+       if (input_data[index] <= 0) {
+         input_grad_data[index] = 0;
+         continue;
+       }
+     }
      for (int c_i = 0; c_i < filter_multiplier; c_i++) {
        int c_out = c_in * filter_multiplier + c_i;
@@ -300,34 +312,33 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
          }
        }
      }
-     int index =
-         ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
-         w_in;
      input_grad_data[index] = value;
    }
  }
}
-template <typename T, int c_filter_multiplier, int c_stride, int c_filter>
+template <typename T, int c_filter_multiplier, int c_stride, int c_filter,
+          bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvInputGradSp(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  if (c_filter_multiplier == 0)
-   KernelDepthwiseConvInputGrad<T>(
-       output_grad_data, filter_data, batch_size, output_channels,
+   KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
+       input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        filter_multiplier, filter_height, filter_width, stride_height,
        stride_width, padding_height, padding_width, dilate_height,
        dilate_width, input_grad_data);
  else if (c_filter == -1)
-   KernelDepthwiseConvInputGrad<T>(
-       output_grad_data, filter_data, batch_size, output_channels,
+   KernelDepthwiseConvInputGrad<T, fuse_relu_before_conv>(
+       input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
        padding_height, padding_width, dilate_height, dilate_width,
        input_grad_data);
  else
-   KernelDepthwiseConvInputGradCFilter<T, c_filter, c_filter_multiplier>(
-       output_grad_data, filter_data, batch_size, output_channels,
+   KernelDepthwiseConvInputGradCFilter<T, c_filter, c_filter_multiplier,
+                                       fuse_relu_before_conv>(
+       input_data, output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
        padding_height, padding_width, dilate_height, dilate_width,
...@@ -335,7 +346,7 @@ __global__ void KernelDepthwiseConvInputGradSp( ...@@ -335,7 +346,7 @@ __global__ void KernelDepthwiseConvInputGradSp(
} }
// Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. // Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
template <typename T> template <typename T, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGrad( __device__ __inline__ void KernelDepthwiseConvFilterGrad(
const T* output_grad_data, const T* input_data, const int num, const T* output_grad_data, const T* input_data, const int num,
const int output_channels, const int output_height, const int output_width, const int output_channels, const int output_height, const int output_width,
...@@ -347,7 +358,6 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad( ...@@ -347,7 +358,6 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad(
T s = 0; T s = 0;
int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x; int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
int lid = lane_id();
for (int image_w = threadIdx.x; image_w < output_width; for (int image_w = threadIdx.x; image_w < output_width;
image_w += blockDim.x) { image_w += blockDim.x) {
...@@ -364,28 +374,28 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad( ...@@ -364,28 +374,28 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad(
if (image_wk < 0 || image_wk >= input_width) continue; if (image_wk < 0 || image_wk >= input_width) continue;
#define gaid(N, C, H, W) \ #define gaid(N, C, H, W) \
((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W)) ((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W))
int input_id = ((bid * (gridDim.z / filter_multiplier) +
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
input_data[((bid * (gridDim.z / filter_multiplier) +
kernel_id / filter_multiplier) * kernel_id / filter_multiplier) *
input_height + input_height +
image_hk) * image_hk) *
input_width + input_width +
image_wk]; image_wk;
if (fuse_relu_before_conv) {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
max(0.0f, input_data[input_id]);
} else {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
input_data[input_id];
}
#undef gaid #undef gaid
} }
} }
} }
#if __CUDA_ARCH__ >= 530 CudaAtomicAddWithWarp(&filter_grad_data[gbid], s);
s = warpReduceSum<T>(s);
if (lid == 0) paddle::platform::CudaAtomicAdd(&filter_grad_data[gbid], s);
#else
paddle::platform::CudaAtomicAdd(&filter_grad_data[gbid], s);
#endif
} }
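In the filter-gradient kernel the fusion shows up on the data side instead: when fuse_relu_before_conv is set, the value multiplied with the output gradient is max(0, input), because the tensor the fused forward pass actually convolves is relu(input). A small sketch of that per-element contribution, using scalars for clarity (names are illustrative, not part of the patch):
import numpy as np
def filter_grad_contrib(output_grad, input_val, fuse_relu_before_conv):
    # Mirrors the branch above: the fused forward convolves relu(input),
    # so relu(input) is what scales the output gradient here.
    x = np.maximum(input_val, 0.0) if fuse_relu_before_conv else input_val
    return output_grad * x
print(filter_grad_contrib(0.5, -2.0, True))   # 0.0  (masked by the fused ReLU)
print(filter_grad_contrib(0.5, -2.0, False))  # -1.0 (plain depthwise conv)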
template <typename T, int c_filter_multiplier> template <typename T, int c_filter_multiplier, bool fuse_relu_before_conv>
__global__ void KernelDepthwiseConvFilterGradSp( __global__ void KernelDepthwiseConvFilterGradSp(
const T* output_grad_data, const T* input_data, const int num, const T* output_grad_data, const T* input_data, const int num,
const int output_channels, const int output_height, const int output_width, const int output_channels, const int output_height, const int output_width,
...@@ -395,14 +405,14 @@ __global__ void KernelDepthwiseConvFilterGradSp( ...@@ -395,14 +405,14 @@ __global__ void KernelDepthwiseConvFilterGradSp(
const int padding_height, const int padding_width, const int dilate_height, const int padding_height, const int padding_width, const int dilate_height,
const int dilate_width, T* filter_grad_data) { const int dilate_width, T* filter_grad_data) {
if (c_filter_multiplier == 0) if (c_filter_multiplier == 0)
KernelDepthwiseConvFilterGrad<T>( KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
output_grad_data, input_data, num, output_channels, output_height, output_grad_data, input_data, num, output_channels, output_height,
output_width, input_channels, input_height, input_width, output_width, input_channels, input_height, input_width,
filter_multiplier, filter_height, filter_width, stride_height, filter_multiplier, filter_height, filter_width, stride_height,
stride_width, padding_height, padding_width, dilate_height, stride_width, padding_height, padding_width, dilate_height,
dilate_width, filter_grad_data); dilate_width, filter_grad_data);
else else
KernelDepthwiseConvFilterGrad<T>( KernelDepthwiseConvFilterGrad<T, fuse_relu_before_conv>(
output_grad_data, input_data, num, output_channels, output_height, output_grad_data, input_data, num, output_channels, output_height,
output_width, input_channels, input_height, input_width, output_width, input_channels, input_height, input_width,
c_filter_multiplier, filter_height, filter_width, stride_height, c_filter_multiplier, filter_height, filter_width, stride_height,
...@@ -415,8 +425,9 @@ __global__ void KernelDepthwiseConvFilterGradSp( ...@@ -415,8 +425,9 @@ __global__ void KernelDepthwiseConvFilterGradSp(
* Ksize, strides, paddings are two elements. These two elements represent * Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively. * height and width, respectively.
*/ */
template <class T> template <class T, bool fuse_relu_before_conv>
class DepthwiseConvFunctor<platform::CUDADeviceContext, T> { class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
fuse_relu_before_conv> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
...@@ -446,6 +457,10 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> { ...@@ -446,6 +457,10 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
T* output_data = output->mutable_data<T>(context.GetPlace()); T* output_data = output->mutable_data<T>(context.GetPlace());
int thread = 512; int thread = 512;
if (output_width > 1024 && output_width <= 2048)
thread = (output_width - 1) / 2 + 1;
else if (output_width > 512 && output_width <= 1024)
thread = output_width;
int blocks = std::min(std::max(thread / output_width, 1), output_height); int blocks = std::min(std::max(thread / output_width, 1), output_height);
dim3 threads(std::min(output_width, thread), blocks, 1); dim3 threads(std::min(output_width, thread), blocks, 1);
dim3 grid(output_channels, batch_size, 1); dim3 grid(output_channels, batch_size, 1);
...@@ -456,8 +471,9 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> { ...@@ -456,8 +471,9 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
stride_height == stride_width && stride_height == c_stride && \ stride_height == stride_width && stride_height == c_stride && \
(ksize_height == ksize_width && ksize_height == c_filter || \ (ksize_height == ksize_width && ksize_height == c_filter || \
c_filter == -1)) { \ c_filter == -1)) { \
KernelDepthwiseConvSp<T, c_filter_multiplier, c_stride, \ KernelDepthwiseConvSp< \
c_filter><<<grid, threads, 0, context.stream()>>>( \ T, c_filter_multiplier, c_stride, c_filter, \
fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
input_data, filter_data, batch_size, output_channels, output_height, \ input_data, filter_data, batch_size, output_channels, output_height, \
output_width, input_channels, input_height, input_width, \ output_width, input_channels, input_height, input_width, \
filter_multiplier, ksize_height, ksize_width, stride_height, \ filter_multiplier, ksize_height, ksize_width, stride_height, \
...@@ -480,8 +496,9 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> { ...@@ -480,8 +496,9 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
} }
}; };
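The thread-count selection added in the forward functor above (and repeated in the two gradient functors below) is a small launch-shape heuristic: keep 512 threads per row by default, use the row width itself when it lies in (512, 1024], and use roughly half of it when it lies in (1024, 2048] so each thread covers two pixels through the strided inner loop. A rough Python transcription, with the function name chosen for illustration:
def pick_launch_dims(output_width, output_height, default_threads=512):
    thread = default_threads
    if 1024 < output_width <= 2048:
        thread = (output_width - 1) // 2 + 1   # roughly half the row
    elif 512 < output_width <= 1024:
        thread = output_width                  # exactly one row
    blocks_y = min(max(thread // output_width, 1), output_height)
    return min(output_width, thread), blocks_y # threads.x, threads.y
print(pick_launch_dims(1536, 64))  # -> (768, 1)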
template <typename T> template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> { class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
fuse_relu_before_conv> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
...@@ -507,11 +524,16 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> { ...@@ -507,11 +524,16 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
const int dilate_height = dilations[0]; const int dilate_height = dilations[0];
const int dilate_width = dilations[1]; const int dilate_width = dilations[1];
const T* input_data = input.data<T>();
const T* filter_data = filter.data<T>(); const T* filter_data = filter.data<T>();
const T* output_grad_data = output_grad.data<T>(); const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int thread = 512; int thread = 512;
if (input_width > 1024 && input_width <= 2048)
thread = (input_width - 1) / 2 + 1;
else if (input_width > 512 && input_width <= 1024)
thread = input_width;
int blocks = std::min(std::max(thread / input_width, 1), input_height); int blocks = std::min(std::max(thread / input_width, 1), input_height);
dim3 threads(std::min(input_width, thread), blocks, 1); dim3 threads(std::min(input_width, thread), blocks, 1);
dim3 grid(input_channels, batch_size, 1); dim3 grid(input_channels, batch_size, 1);
...@@ -524,13 +546,13 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> { ...@@ -524,13 +546,13 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
(ksize_height == ksize_width && ksize_height == c_filter || \ (ksize_height == ksize_width && ksize_height == c_filter || \
c_filter == -1)) { \ c_filter == -1)) { \
KernelDepthwiseConvInputGradSp< \ KernelDepthwiseConvInputGradSp< \
T, c_filter_multiplier, c_stride, \ T, c_filter_multiplier, c_stride, c_filter, \
c_filter><<<grid, threads, 0, context.stream()>>>( \ fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
output_grad_data, filter_data, batch_size, output_channels, \ input_data, output_grad_data, filter_data, batch_size, \
output_height, output_width, input_channels, input_height, \ output_channels, output_height, output_width, input_channels, \
input_width, filter_multiplier, ksize_height, ksize_width, \ input_height, input_width, filter_multiplier, ksize_height, \
stride_height, stride_width, padding_height, padding_width, \ ksize_width, stride_height, stride_width, padding_height, \
dilate_height, dilate_width, input_grad_data); \ padding_width, dilate_height, dilate_width, input_grad_data); \
return; \ return; \
} }
check_case(1, 1, 3); check_case(1, 1, 3);
...@@ -552,8 +574,9 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> { ...@@ -552,8 +574,9 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
} }
}; };
template <typename T> template <typename T, bool fuse_relu_before_conv>
class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> { class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
fuse_relu_before_conv> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
...@@ -583,6 +606,10 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> { ...@@ -583,6 +606,10 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> {
T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace()); T* filter_grad_data = filter_grad->mutable_data<T>(context.GetPlace());
int block_size = 512; int block_size = 512;
if (output_width > 1024 && output_width <= 2048)
block_size = (output_width - 1) / 2 + 1;
else if (output_width > 512 && output_width <= 1024)
block_size = output_width;
int crop_output_height = int crop_output_height =
std::min(std::max(block_size / output_width, 1), output_height); std::min(std::max(block_size / output_width, 1), output_height);
dim3 grid(ksize_width, ksize_height, output_channels); dim3 grid(ksize_width, ksize_height, output_channels);
...@@ -592,7 +619,8 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> { ...@@ -592,7 +619,8 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> {
#define check_case(c_filter_multiplier) \ #define check_case(c_filter_multiplier) \
if (c_filter_multiplier == 0 || c_filter_multiplier == filter_multiplier) { \ if (c_filter_multiplier == 0 || c_filter_multiplier == filter_multiplier) { \
KernelDepthwiseConvFilterGradSp< \ KernelDepthwiseConvFilterGradSp< \
T, c_filter_multiplier><<<grid, threads, 0, context.stream()>>>( \ T, c_filter_multiplier, \
fuse_relu_before_conv><<<grid, threads, 0, context.stream()>>>( \
output_grad_data, input_data, batch_size, output_channels, \ output_grad_data, input_data, batch_size, output_channels, \
output_height, output_width, input_channels, input_height, \ output_height, output_width, input_channels, input_height, \
input_width, filter_multiplier, ksize_height, ksize_width, \ input_width, filter_multiplier, ksize_height, ksize_width, \
...@@ -606,18 +634,31 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> { ...@@ -606,18 +634,31 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T> {
} }
}; };
template class DepthwiseConvFunctor<platform::CUDADeviceContext, float>; template class DepthwiseConvFunctor<platform::CUDADeviceContext, float, false>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double>; template class DepthwiseConvFunctor<platform::CUDADeviceContext, double, false>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, float,
false>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
float>; double, false>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
float, false>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
double, false>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext, float, true>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext, double, true>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, float,
true>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
double>; double, true>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
float>; float, true>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
double>; double, true>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -26,7 +26,8 @@ namespace math { ...@@ -26,7 +26,8 @@ namespace math {
* \brief Compute the depthwise convolution which include * \brief Compute the depthwise convolution which include
* forward process and backpropagation process * forward process and backpropagation process
*/ */
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T,
bool fuse_relu_before_conv = false>
class DepthwiseConvFunctor { class DepthwiseConvFunctor {
public: public:
void operator()(const DeviceContext& context, const framework::Tensor& input, void operator()(const DeviceContext& context, const framework::Tensor& input,
...@@ -36,7 +37,8 @@ class DepthwiseConvFunctor { ...@@ -36,7 +37,8 @@ class DepthwiseConvFunctor {
const std::vector<int>& dilations, framework::Tensor* output); const std::vector<int>& dilations, framework::Tensor* output);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T,
bool fuse_relu_before_conv = false>
class DepthwiseConvInputGradFunctor { class DepthwiseConvInputGradFunctor {
public: public:
void operator()(const DeviceContext& context, const framework::Tensor& input, void operator()(const DeviceContext& context, const framework::Tensor& input,
...@@ -48,7 +50,8 @@ class DepthwiseConvInputGradFunctor { ...@@ -48,7 +50,8 @@ class DepthwiseConvInputGradFunctor {
framework::Tensor* input_grad); framework::Tensor* input_grad);
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T,
bool fuse_relu_before_conv = false>
class DepthwiseConvFilterGradFunctor { class DepthwiseConvFilterGradFunctor {
public: public:
void operator()(const DeviceContext& context, const framework::Tensor& input, void operator()(const DeviceContext& context, const framework::Tensor& input,
......
...@@ -1023,6 +1023,20 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1023,6 +1023,20 @@ All parameter, weight, gradient are variables in Paddle.
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
to fuse elementwise_add_op and activation_op, to fuse elementwise_add_op and activation_op,
it may make the execution faster. Default False)DOC") it may make the execution faster. Default False)DOC")
.def_property(
"fuse_relu_depthwise_conv",
[](const BuildStrategy &self) {
return self.fuse_relu_depthwise_conv_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE(!self.IsFinalized(), "BuildStrategy is finalized.");
self.fuse_relu_depthwise_conv_ = b;
},
R"DOC(The type is BOOL, fuse_relu_depthwise_conv indicate whether
to fuse relu and depthwise_conv2d,
it will save GPU memory and may make the execution faster.
This options is only available in GPU devices.
Default False)DOC")
.def_property( .def_property(
"memory_optimize", "memory_optimize",
[](const BuildStrategy &self) { return self.memory_optimize_; }, [](const BuildStrategy &self) { return self.memory_optimize_; },
......
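The property exposed above is what user code flips to enable the pass. A sketch of end-to-end usage with the Fluid API of this tree (requires a CUDA build; the layer parameters are illustrative, the essential part is a relu feeding a depthwise conv2d with use_cudnn=False and build_strategy.fuse_relu_depthwise_conv set to True):
import numpy as np
import paddle.fluid as fluid
# relu -> depthwise conv2d chain, the pattern the new pass fuses
img = fluid.layers.data(name='image', shape=[4, 8, 8], dtype='float32')
hidden = fluid.layers.relu(img)
hidden = fluid.layers.conv2d(
    hidden, num_filters=4, filter_size=3, padding=1,
    groups=4, use_cudnn=False, bias_attr=False)  # groups == channels => depthwise_conv2d
loss = fluid.layers.mean(hidden)
fluid.optimizer.SGD(learning_rate=1e-3).minimize(loss)
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_relu_depthwise_conv = True   # flag added in this commit
exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())
pe = fluid.ParallelExecutor(use_cuda=True, loss_name=loss.name,
                            build_strategy=build_strategy)
feed = {'image': np.random.random((2, 4, 8, 8)).astype('float32')}
loss_val, = pe.run(feed=feed, fetch_list=[loss.name])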
...@@ -76,7 +76,7 @@ def memory_usage(program, batch_size): ...@@ -76,7 +76,7 @@ def memory_usage(program, batch_size):
# Get the var_name list of first block and calculate # Get the var_name list of first block and calculate
total_memory = 0.0 total_memory = 0.0
processed_var_names = set() processed_var_names = set(["@EMPTY@"])
for op in program.global_block().ops: for op in program.global_block().ops:
for var_name in op.output_arg_names: for var_name in op.output_arg_names:
if var_name in processed_var_names: if var_name in processed_var_names:
......
...@@ -1972,6 +1972,7 @@ def conv2d(input, ...@@ -1972,6 +1972,7 @@ def conv2d(input,
'groups': groups, 'groups': groups,
'use_cudnn': use_cudnn, 'use_cudnn': use_cudnn,
'use_mkldnn': False, 'use_mkldnn': False,
'fuse_relu_before_depthwise_conv': False
}) })
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
......
...@@ -42,6 +42,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -42,6 +42,7 @@ class TestParallelExecutorBase(unittest.TestCase):
use_reduce=False, use_reduce=False,
use_ir_memory_optimize=False, use_ir_memory_optimize=False,
fuse_elewise_add_act_ops=False, fuse_elewise_add_act_ops=False,
fuse_relu_depthwise_conv=False,
optimizer=fluid.optimizer.Adam, optimizer=fluid.optimizer.Adam,
use_fast_executor=False, use_fast_executor=False,
enable_sequential_execution=False): enable_sequential_execution=False):
...@@ -60,6 +61,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -60,6 +61,7 @@ class TestParallelExecutorBase(unittest.TestCase):
loss = method(use_feed=feed_dict is not None) loss = method(use_feed=feed_dict is not None)
if optimizer:
optimizer().minimize(loss) optimizer().minimize(loss)
if memory_opt: if memory_opt:
...@@ -76,6 +78,7 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -76,6 +78,7 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \ build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.fuse_relu_depthwise_conv = fuse_relu_depthwise_conv
build_strategy.memory_optimize = use_ir_memory_optimize build_strategy.memory_optimize = use_ir_memory_optimize
build_strategy.enable_sequential_execution = enable_sequential_execution build_strategy.enable_sequential_execution = enable_sequential_execution
if use_cuda and core.is_compiled_with_cuda(): if use_cuda and core.is_compiled_with_cuda():
......
...@@ -70,6 +70,7 @@ class TestConv2dOp(OpTest): ...@@ -70,6 +70,7 @@ class TestConv2dOp(OpTest):
self.exhaustive_search = False self.exhaustive_search = False
self.use_cuda = False self.use_cuda = False
self.use_mkldnn = False self.use_mkldnn = False
self.fuse_relu_before_depthwise_conv = False
self.data_format = "AnyLayout" self.data_format = "AnyLayout"
self.dtype = np.float32 self.dtype = np.float32
self.init_kernel_type() self.init_kernel_type()
...@@ -84,8 +85,17 @@ class TestConv2dOp(OpTest): ...@@ -84,8 +85,17 @@ class TestConv2dOp(OpTest):
} }
input = np.random.random(self.input_size).astype(self.dtype) input = np.random.random(self.input_size).astype(self.dtype)
if not self.testcuda():
self.fuse_relu_before_depthwise_conv = False
if self.fuse_relu_before_depthwise_conv:
input = input - 0.5
input -= (input < 0) * 0.1
input += (input >= 0) * 0.1
input2 = np.maximum(input, 0.0)
else:
input2 = input
filter = np.random.random(self.filter_size).astype(self.dtype) filter = np.random.random(self.filter_size).astype(self.dtype)
output, _, _, _, _ = conv2d_forward_naive(input, filter, self.groups, output, _, _, _, _ = conv2d_forward_naive(input2, filter, self.groups,
conv2d_param) conv2d_param)
output = output.astype(self.dtype) output = output.astype(self.dtype)
...@@ -101,6 +111,8 @@ class TestConv2dOp(OpTest): ...@@ -101,6 +111,8 @@ class TestConv2dOp(OpTest):
'use_cudnn': self.use_cudnn, 'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn, 'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format, 'data_format': self.data_format,
'fuse_relu_before_depthwise_conv':
self.fuse_relu_before_depthwise_conv,
'exhaustive_search': self.exhaustive_search 'exhaustive_search': self.exhaustive_search
} }
self.outputs = {'Output': output} self.outputs = {'Output': output}
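The input massaging added above (centering around zero, then pushing every value at least 0.1 away from it) is presumably there so that no element sits near the ReLU kink, where the finite-difference gradient check would disagree with the analytic gradient of the fused op; input2 = max(input, 0) is then the tensor that conv2d_forward_naive sees as reference. A small NumPy illustration of the same transform (illustrative only, not part of the patch):
import numpy as np
np.random.seed(0)
x = np.random.random((2, 3)).astype(np.float32)  # values in [0, 1)
x = x - 0.5                       # roughly centre around zero
x -= (x < 0) * 0.1                # negatives move further below zero
x += (x >= 0) * 0.1               # non-negatives move further above zero
assert np.all(np.abs(x) > 0.09)   # nothing is left near the ReLU kink at 0
print(np.maximum(x, 0.0))         # what the reference convolution consumes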
...@@ -364,6 +376,78 @@ class TestDepthwiseConvWithDilation2(TestConv2dOp): ...@@ -364,6 +376,78 @@ class TestDepthwiseConvWithDilation2(TestConv2dOp):
self.op_type = "depthwise_conv2d" self.op_type = "depthwise_conv2d"
class TestDepthwiseConvandFuse(TestConv2dOp):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestDepthwiseConv2andFuse(TestConv2dOp):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [3, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestDepthwiseConv3andFuse(TestConv2dOp):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestDepthwiseConvWithDilationandFuse(TestConv2dOp):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.dilations = [2, 2]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestDepthwiseConvWithDilation2andFuse(TestConv2dOp):
def init_test_case(self):
self.fuse_relu_before_depthwise_conv = True
self.use_cuda = True
self.pad = [1, 1]
self.stride = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
self.groups = 3
self.dilations = [2, 2]
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] // self.groups
self.filter_size = [6, f_c, 3, 3]
self.op_type = "depthwise_conv2d"
class TestCUDNNExhaustiveSearch(TestConv2dOp): class TestCUDNNExhaustiveSearch(TestConv2dOp):
def init_kernel_type(self): def init_kernel_type(self):
self.use_cudnn = True self.use_cudnn = True
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parallel_executor_test_base import TestParallelExecutorBase
import paddle.fluid as fluid
import paddle.fluid.core as core
import numpy as np
import paddle
import paddle.dataset.mnist as mnist
import unittest
import os
MNIST_RECORDIO_FILE = "./mnist_test_pe.recordio"
def norm(*args, **kargs):
return fluid.layers.batch_norm(*args, **kargs)
def sep_conv(input, channel, stride, filter, dilation=1, act=None):
# with scope('depthwise'):
input = fluid.layers.conv2d(
input,
input.shape[1],
filter,
stride,
groups=input.shape[1],
padding=(filter // 2) * dilation,
dilation=dilation,
use_cudnn=False,
bias_attr=False)
input = norm(input)
if act: input = act(input)
# with scope('pointwise'):
input = fluid.layers.conv2d(
input, channel, 1, 1, groups=1, padding=0, bias_attr=False)
input = norm(input)
if act: input = act(input)
return input
def simple_depthwise_net(use_feed):
if use_feed:
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
else:
reader = fluid.layers.open_files(
filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
hidden = fluid.layers.reshape(img, (-1, 1, 28, 28))
for _ in range(4):
hidden = sep_conv(hidden, channel=200, stride=2, filter=5)
hidden = fluid.layers.relu(hidden)
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestMNIST(TestParallelExecutorBase):
@classmethod
def setUpClass(cls):
os.environ['CPU_NUM'] = str(4)
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
reader = paddle.batch(mnist.train(), batch_size=4)
feeder = fluid.DataFeeder(
feed_list=[ # order is image and label
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
fluid.recordio_writer.convert_reader_to_recordio_file(
MNIST_RECORDIO_FILE, reader, feeder)
def _init_data(self, random=True):
np.random.seed(5)
if random:
img = np.random.random(size=[32, 784]).astype(np.float32)
else:
img = np.ones(shape=[32, 784], dtype='float32')
label = np.ones(shape=[32, 1], dtype='int64')
return img, label
def _compare(self, model, use_cuda, random_data=True, only_forward=False):
if use_cuda and not core.is_compiled_with_cuda():
return
img, label = self._init_data(random_data)
def _optimizer(learning_rate=1e-6):
optimizer = fluid.optimizer.SGD(
learning_rate=learning_rate,
regularization=fluid.regularizer.L2Decay(1e-6))
return optimizer
if only_forward:
_optimizer = None
fuse_op_first_loss, fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
fuse_relu_depthwise_conv=True,
use_ir_memory_optimize=True,
memory_opt=False,
optimizer=_optimizer)
not_fuse_op_first_loss, not_fuse_op_last_loss = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
use_cuda=use_cuda,
fuse_relu_depthwise_conv=False,
memory_opt=False,
optimizer=_optimizer)
for loss in zip(not_fuse_op_first_loss, fuse_op_first_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(not_fuse_op_last_loss, fuse_op_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
def test_simple_depthwise_with_fuse_op(self):
self._compare(simple_depthwise_net, True)
self._compare(simple_depthwise_net, False)
def test_simple_depthwise_with_fuse_op_only_forward(self):
self._compare(simple_depthwise_net, True, only_forward=True)
self._compare(simple_depthwise_net, False, only_forward=True)
if __name__ == '__main__':
unittest.main()