Commit d4e8d707 authored by dongzhihong

Merge remote-tracking branch 'origin/develop' into fix/sequence_pad

test=develop
......@@ -67,8 +67,8 @@ paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size',
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type', 'is_test'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False))
paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
......@@ -103,7 +103,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100))
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode'], varargs=None, keywords=None, defaults=(False, -100, False))
paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
......@@ -178,6 +178,7 @@ paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], var
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
......
......@@ -35,13 +35,15 @@ if(WITH_GPU)
all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle graph graph_helper pass)
endif()
cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS graph graph_helper pass)
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
if(WITH_GPU)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass sequential_execution_pass)
else()
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto sequential_execution_pass)
endif()
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
#include "paddle/fluid/framework/details/multi_devices_graph_print_pass.h"
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
......@@ -27,6 +28,10 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
public:
explicit ParallelExecutorPassBuilder(const BuildStrategy &strategy)
: ir::PassBuilder(), strategy_(strategy) {
if (strategy_.enable_sequential_execution_) {
AppendPass("sequential_execution_pass");
}
// Add a graph viz pass to record a graph.
if (!strategy_.debug_graphviz_path_.empty()) {
auto viz_pass = AppendPass("graph_viz_pass");
......@@ -110,6 +115,11 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->Erase("nccl_ctxs");
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif
} else if (pass->Type() == "sequential_execution_pass") {
pass->Erase(kAllOpDescs);
pass->Set<const std::vector<OpDesc *>>(
kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
}
graph = pass->Apply(std::move(graph));
}
......@@ -125,3 +135,4 @@ USE_PASS(multi_batch_merge_pass);
USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
USE_PASS(sequential_execution_pass);
......@@ -69,6 +69,8 @@ struct BuildStrategy {
bool enable_data_balance_{false};
bool enable_sequential_execution_{false};
bool fuse_broadcast_op_{false};
// Users normally don't need to call this API.
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace details {
static bool IsSameOpDesc(OpDesc *op1, OpDesc *op2) {
return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() &&
op1->Outputs() == op2->Outputs();
}
std::unique_ptr<ir::Graph> SequentialExecutionPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
// FIXME(zjl): Inserting dependencies between some distributed ops may cause
// the multi_devices_graph_pass to fail. So we skip these ops here.
// Indeed, maybe we should not insert dependencies between these ops
// casually, since that may easily cause deadlock.
// We should add more skipped distributed ops when errors are found in
// multi_devices_graph_pass.
static std::unordered_set<std::string> skip_dist_ops{
"send", "recv", "send_barrier", "fetch_barrier"};
auto &ops = Get<const std::vector<OpDesc *>>(kAllOpDescs);
std::vector<ir::Node *> op_node_list;
op_node_list.reserve(ops.size());
std::unordered_map<ir::Node *, size_t> op_deps;
std::unordered_map<ir::Node *, std::unordered_set<ir::Node *>> pending_ops;
std::unordered_set<ir::Node *> ready_ops;
for (ir::Node *node : graph->Nodes()) {
if (!node->IsOp()) continue;
std::unordered_set<ir::Node *> preceding_ops;
for (auto *in : node->inputs) {
PADDLE_ENFORCE(in->IsVar(),
"Preceding Node of Op Nodes must be Var Node");
if (in->inputs.empty()) continue;
PADDLE_ENFORCE(in->inputs.size() == 1 && in->inputs[0]->IsOp(),
"Preceding Op Node of Var Node must be unique");
preceding_ops.insert(in->inputs[0]);
pending_ops[in->inputs[0]].insert(node);
}
op_deps[node] = preceding_ops.size();
if (preceding_ops.empty()) {
ready_ops.insert(node);
}
}
for (auto *op_desc : ops) {
ir::Node *found_node = nullptr;
for (auto *node : ready_ops) {
if (IsSameOpDesc(op_desc, node->Op())) {
PADDLE_ENFORCE(found_node == nullptr,
"Found multiple op_desc in graph: %s", op_desc->Type());
found_node = node;
}
}
PADDLE_ENFORCE_NOT_NULL(found_node, "Cannot find op_desc in graph: %s",
op_desc->Type());
for (auto *pending_op : pending_ops[found_node]) {
if (--op_deps.at(pending_op) == 0) {
ready_ops.insert(pending_op);
}
}
ready_ops.erase(found_node);
if (skip_dist_ops.count(op_desc->Type()) == 0) {
op_node_list.push_back(found_node);
}
}
for (size_t i = 1; i < op_node_list.size(); ++i) {
auto *dep_var = graph->CreateControlDepVar();
op_node_list[i]->inputs.push_back(dep_var);
op_node_list[i - 1]->outputs.push_back(dep_var);
dep_var->outputs.push_back(op_node_list[i]);
dep_var->inputs.push_back(op_node_list[i - 1]);
VLOG(10) << "Add dependencies between " << op_node_list[i - 1]->Name()
<< " and " << op_node_list[i]->Name();
}
return graph;
}
} // namespace details
} // namespace framework
} // namespace paddle
REGISTER_PASS(sequential_execution_pass,
paddle::framework::details::SequentialExecutionPass)
.RequirePassAttr(paddle::framework::details::kAllOpDescs);
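For orientation, the pass is not run directly; it is switched on through BuildStrategy (shown earlier in this diff). A minimal sketch of the opt-in, not part of this patch, assuming only the BuildStrategy fields visible above:

// Hedged sketch: opting in to deterministic, program-order execution.
#include "paddle/fluid/framework/details/build_strategy.h"

void EnableSequentialExecution(
    paddle::framework::details::BuildStrategy* strategy) {
  // ParallelExecutorPassBuilder appends "sequential_execution_pass" when this
  // flag is set, and BuildStrategy::Apply() injects kAllOpDescs
  // (main_program.Block(0).AllOps()) before the pass runs.
  strategy->enable_sequential_execution_ = true;
}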
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace details {
constexpr char kAllOpDescs[] = "all_op_descs";
class SequentialExecutionPass : public ir::Pass {
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -41,6 +41,7 @@ pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base)
pass_library(depthwise_conv_mkldnn_pass base)
pass_library(conv_bias_mkldnn_fuse_pass inference)
pass_library(conv_relu_mkldnn_fuse_pass inference)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
......@@ -59,6 +60,7 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
endif ()
......@@ -31,7 +31,8 @@ class ConvReLUFusePass : public FusePassBase {
virtual ~ConvReLUFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace ir
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
......@@ -36,6 +37,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetInput("X", inputs);
}
op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
// a->OP0->b
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetrieveNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
std::unique_ptr<ir::Graph> DepthwiseConvMKLDNNPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("depthwise_conv_mkldnn_pass", graph.get());
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
pattern->NewNode("depthwise_conv")
->assert_is_op("depthwise_conv2d")
->assert_op_attr("use_mkldnn", true);
int found_depthwise_conv_mkldnn_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(3) << "handle DepthwiseConvMKLDNN fuse";
GET_NODE(depthwise_conv, (*pattern));
depthwise_conv->Op()->SetType("conv2d");
found_depthwise_conv_mkldnn_count++;
};
gpd(graph.get(), handler);
AddStatis(found_depthwise_conv_mkldnn_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(depthwise_conv_mkldnn_pass,
paddle::framework::ir::DepthwiseConvMKLDNNPass);
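The whole rewrite is a type change on the matched op desc; presumably this works because MKL-DNN's conv2d kernel already covers the depthwise case (groups equal to input channels), so no inputs or attributes need to move. A hedged sketch of the effect at the OpDesc level (hypothetical helper, not from this patch):

#include "paddle/fluid/framework/op_desc.h"

// Illustrative only: what the handler above does to a matched node's OpDesc.
void RewriteDepthwiseToConv(paddle::framework::OpDesc* op) {
  // Before: type == "depthwise_conv2d", use_mkldnn == true.
  // After:  type == "conv2d"; Input/Filter/Bias/Out and attrs are untouched.
  op->SetType("conv2d");
}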
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class DepthwiseConvMKLDNNPass : public FusePassBase {
public:
virtual ~DepthwiseConvMKLDNNPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/depthwise_conv_mkldnn_pass.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn = false) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
op->SetOutput("Out", outputs);
}
// (a, weights, bias)->depthwise conv mkldnn->b
// (b, weights2, bias2)->depthwise conv no mkldnn->c
// (c, weights3, bias3)->conv mkldnn->d
// (d, weights4, bias4)->conv no mkldnn->e
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v : std::vector<std::string>(
{"a", "b", "c", "d", "e", "weights", "bias", "weights2", "bias2",
"weights3", "bias3", "weights4", "bias4"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias" || v == "weights2" || v == "bias2" ||
v == "weights3" || v == "bias3" || v == "weights4" || v == "bias4") {
var->SetPersistable(true);
}
}
// depthwise conv with MKL-DNN
SetOp(&prog, "depthwise_conv2d", "conv1",
std::vector<std::string>({"a", "weights", "bias"}),
std::vector<std::string>({"b"}), true);
// depthwise conv without MKL-DNN
SetOp(&prog, "depthwise_conv2d", "conv2",
std::vector<std::string>({"b", "weights2", "bias2"}),
std::vector<std::string>({"c"}), false);
// conv with MKL-DNN
SetOp(&prog, "conv2d", "conv3",
std::vector<std::string>({"c", "weights3", "bias3"}),
std::vector<std::string>({"d"}), true);
// conv without MKL-DNN
SetOp(&prog, "conv2d", "conv4",
std::vector<std::string>({"d", "weights4", "bias4"}),
std::vector<std::string>({"e"}), false);
return prog;
}
TEST(DepthwiseConvMKLDNNPass, basic) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("depthwise_conv_mkldnn_pass");
struct counters {
int mkldnn_depthwise_conv_nodes;
int other_depthwise_conv_nodes;
int mkldnn_conv_nodes;
int other_conv_nodes;
};
counters before{1, 1, 1, 1};
graph = pass->Apply(std::move(graph));
// initialize counters before loop
counters after{0, 0, 0, 0};
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "conv2d") {
if (boost::get<bool>(op->GetAttr("use_mkldnn")))
after.mkldnn_conv_nodes++;
else
after.other_conv_nodes++;
} else if (op->Type() == "depthwise_conv2d") {
if (boost::get<bool>(op->GetAttr("use_mkldnn")))
after.mkldnn_depthwise_conv_nodes++;
else
after.other_depthwise_conv_nodes++;
}
}
}
EXPECT_EQ(after.other_depthwise_conv_nodes,
before.other_depthwise_conv_nodes);
EXPECT_EQ(after.other_conv_nodes, before.other_conv_nodes);
EXPECT_EQ(after.mkldnn_depthwise_conv_nodes,
before.mkldnn_depthwise_conv_nodes - 1);
EXPECT_EQ(after.mkldnn_conv_nodes, before.mkldnn_conv_nodes + 1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(depthwise_conv_mkldnn_pass);
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
......@@ -32,6 +33,8 @@ void SetOp(ProgramDesc* prog, const std::string& type,
op->SetInput("X", inputs);
}
op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
// a->OP0->b
......
......@@ -23,8 +23,62 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
namespace {
void CheckProgram(const ProgramDesc &program) {
std::map<int, bool> visit;
#define _INT(role) static_cast<int>(role)
for (size_t i = 0; i < program.Size(); ++i) {
for (OpDesc *op : program.Block(i).AllOps()) {
// For backward compatibility, some programs don't have the op role attribute set.
if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
int role_id = boost::get<int>(
op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
visit[role_id] = true;
switch (role_id) {
case _INT(OpRole::kForward):
PADDLE_ENFORCE(
visit.find(_INT(OpRole::kBackward)) == visit.end(),
"Cannot add forward operator before backward operator.");
break;
case _INT(OpRole::kBackward):
case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
PADDLE_ENFORCE(
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
"Cannot add backward operator before optimize operator.");
break;
case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
_INT(OpRole::kLoss)) == visit.end(),
"Cannot add backward|loss operator before "
"forward|loss operator.");
PADDLE_ENFORCE(
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
"Cannot add backward operator before optimize operator.");
break;
case _INT(OpRole::kOptimize):
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
"Optimize operators must follow backward operator.");
break;
case _INT(OpRole::kLRSched):
case _INT(OpRole::kDist):
case _INT(OpRole::kRPC):
case _INT(OpRole::kNotSpecified):
break;
default:
LOG(FATAL) << "Unknown operator role. Don't add new role because "
"you don't know what you are doing.";
}
}
}
#undef _INT
}
} // namespace
Graph::Graph(const ProgramDesc &program) : program_(program) {
CheckProgram(program_);
// Make the node ids start from 0.
Node::ResetId();
auto var_nodes = InitFromProgram(program_);
......
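CheckProgram relies on op roles being bit flags that may be OR-combined (an op can be both backward and loss, for example), and the switch above matches those exact combinations. A small sketch of such a check, using only the enum values referenced above:

#include "paddle/fluid/framework/op_proto_maker.h"

// Hedged sketch: testing for the combined forward|loss role id, exactly as
// the corresponding case label in CheckProgram() does.
bool IsForwardLoss(int role_id) {
  using paddle::framework::OpRole;
  return role_id == (static_cast<int>(OpRole::kForward) |
                     static_cast<int>(OpRole::kLoss));
}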
......@@ -79,6 +79,7 @@ class Analyzer : public OrderedRegistry<PassManager> {
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
#ifdef PADDLE_WITH_MKLDNN
"depthwise_conv_mkldnn_pass", //
"conv_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass", //
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
......@@ -130,6 +131,8 @@ void SetOp(framework::ProgramDesc* prog, const std::string& type,
op->SetType(type);
op->SetInput("Xs", inputs);
op->SetOutput("Xs", outputs);
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(framework::OpRole::kForward));
}
TEST(DataFlowGraph, Build_IR_Graph) {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using ScopedSpatialTransformerDescriptor =
platform::ScopedSpatialTransformerDescriptor;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
class CUDNNGridSampleOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace");
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output = ctx.Output<Tensor>("Output");
int n = input->dims()[0];
int c = input->dims()[1];
int h = input->dims()[2];
int w = input->dims()[3];
const int size[4] = {n, c, h, w};
const T* input_data = input->data<T>();
const T* grid_data = grid->data<T>();
T* output_data = output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
ScopedSpatialTransformerDescriptor st_desc;
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
st_desc.descriptor<T>(4, size);
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(output->dims()));
CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerForward(
handle, cudnn_st_desc, CudnnDataType<T>::kOne(), cudnn_input_desc,
input_data, grid_data, CudnnDataType<T>::kZero(), cudnn_output_desc,
output_data));
}
};
template <typename T>
class CUDNNGridSampleGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace");
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
auto output_grad_dims = output_grad->dims();
const int n = output_grad_dims[0];
const int c = output_grad_dims[1];
const int h = output_grad_dims[2];
const int w = output_grad_dims[3];
const int size[4] = {n, c, h, w};
ScopedSpatialTransformerDescriptor st_dest;
cudnnSpatialTransformerDescriptor_t cudnn_st_dest =
st_dest.descriptor<T>(4, size);
const T* input_data = input->data<T>();
const T* grid_data = grid->data<T>();
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data =
input_grad->mutable_data<T>(output_grad_dims, ctx.GetPlace());
T* grid_grad_data =
grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor input_grad_desc;
ScopedTensorDescriptor output_grad_desc;
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t cudnn_input_grad_desc =
input_grad_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(input_grad->dims()));
cudnnTensorDescriptor_t cudnn_output_grad_desc =
output_grad_desc.descriptor<T>(
DataLayout::kNCHW, framework::vectorize2int(output_grad->dims()));
CUDNN_ENFORCE(platform::dynload::cudnnSpatialTfSamplerBackward(
handle, cudnn_st_dest, CudnnDataType<T>::kOne(), cudnn_input_desc,
input_data, CudnnDataType<T>::kZero(), cudnn_input_grad_desc,
input_grad_data, CudnnDataType<T>::kOne(), cudnn_output_grad_desc,
output_grad_data, grid_data, CudnnDataType<T>::kZero(),
grid_grad_data));
}
};
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
REGISTER_OP_KERNEL(grid_sampler, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNGridSampleOpKernel<float>,
paddle::operators::CUDNNGridSampleOpKernel<double>);
REGISTER_OP_KERNEL(grid_sampler_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNGridSampleGradOpKernel<float>,
paddle::operators::CUDNNGridSampleGradOpKernel<double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class GridSampleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of GridSampleOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grid"),
"Input(Grid) of GridSampleOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output(Output) of GridSampleOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
PADDLE_ENFORCE(x_dims.size() == 4,
"Input(X) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(grid_dims.size() == 4,
"Input(Grid) of GridSampleOp should be 4-D Tensor.");
PADDLE_ENFORCE(grid_dims[3] == 2, "Input(Grid) dims[3] should be 2.");
PADDLE_ENFORCE_EQ(grid_dims[0], x_dims[0],
"Input(X) and Input(Grid) dims[0] should be equal.");
PADDLE_ENFORCE_EQ(
grid_dims[1], x_dims[2],
"Input(X) dims[2] and Input(Grid) dims[1] should be equal.");
PADDLE_ENFORCE_EQ(
grid_dims[2], x_dims[3],
"Input(X) dims[3] and Input(Grid) dims[2] should be equal.");
ctx->SetOutputDim("Output", x_dims);
ctx->ShareLoD("X", "Output");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
class GridSampleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) The input data of GridSampleOp, "
"This is a 4-D tensor with shape of [N, C, H, W]");
AddInput(
"Grid",
"(Tensor) The input grid of GridSampleOp generated by AffineGridOp, "
"This is a 4-D tensor with shape of [N, H, W, 2] is the concatenation "
"of x and y coordinates with shape [N, H, W] in last dimention");
AddOutput("Output", "(Tensor) Output tensor with shape [N, C, H, W]");
AddAttr<bool>(
"use_cudnn",
"(bool, default true) Only used in cudnn kernel, need install cudnn")
.SetDefault(true);
AddComment(R"DOC(
This operation samples the input X by using bilinear interpolation based on
a flow field grid, which is usually generated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x indexes the 4th dimension
(the width dimension) of the input data X and grid_y indexes the 3rd
dimension (the height dimension). The final result is the bilinear
interpolation value of the 4 nearest corner points.
Step 1:
Get (x, y) grid coordinates and scale them from [-1, 1] to [0, H-1] (y) and [0, W-1] (x).
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Index the input data X with the grid coordinates (x, y) in each [H, W] area,
and bilinearly interpolate the point value from the 4 nearest corner points.
wn ------- y_n ------- en
| | |
| d_n |
| | |
x_w --d_w-- grid--d_e-- x_e
| | |
| d_s |
| | |
ws ------- y_s ------- es
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_n + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-west point value
es = X[:, :, y_s, x_e] // south-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
)DOC");
}
};
class GridSampleOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto input_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
}
if (ctx->HasOutput(framework::GradVarName("Grid"))) {
ctx->SetOutputDim(framework::GradVarName("Grid"), grid_dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
class GridSampleGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("grid_sampler_grad");
op->SetInput("X", Input("X"));
op->SetInput("Grid", Input("Grid"));
op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Grid"), InputGrad("Grid"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker,
ops::GridSampleGradMaker);
REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad);
REGISTER_OP_CPU_KERNEL(
grid_sampler,
ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
grid_sampler_grad,
ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
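To make the DOC formula above concrete, here is a hedged, standalone scalar reference of a single bilinear sample (plain C++, independent of Paddle; the zero-outside behavior mirrors isInBound() in the kernel header below):

#include <cmath>
#include <vector>

// Hedged reference (not part of this patch): one bilinear sample of a single
// [h, w] channel, following Step 1/Step 2 of the GridSampleOpMaker comment.
float BilinearSample(const std::vector<float>& img, int h, int w,
                     float gx, float gy) {
  // Step 1: scale grid coords from [-1, 1] to [0, w-1] (x) and [0, h-1] (y).
  float x = 0.5f * (gx + 1.0f) * (w - 1);
  float y = 0.5f * (gy + 1.0f) * (h - 1);
  int x_w = static_cast<int>(std::floor(x)), x_e = x_w + 1;
  int y_n = static_cast<int>(std::floor(y)), y_s = y_n + 1;
  float d_w = x - x_w, d_e = x_e - x, d_n = y - y_n, d_s = y_s - y;
  // Out-of-bound corners contribute zero.
  auto at = [&](int yy, int xx) {
    return (xx < 0 || xx >= w || yy < 0 || yy >= h) ? 0.0f : img[yy * w + xx];
  };
  // Step 2: each corner is weighted by the distances to the opposite sides.
  return at(y_n, x_w) * d_e * d_s + at(y_n, x_e) * d_w * d_s +
         at(y_s, x_w) * d_e * d_n + at(y_s, x_e) * d_w * d_n;
}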
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using Array3 = Eigen::DSizes<int64_t, 3>;
using Array4 = Eigen::DSizes<int64_t, 4>;
template <typename T>
static inline bool isInBound(T x, T y, T x_max, T y_max) {
if (x < 0 || x > x_max || y < 0 || y > y_max) {
return false;
}
return true;
}
template <typename T>
static void CalcGridLocations(const platform::CPUDeviceContext& ctx,
const Tensor& grid, Tensor* x_w, Tensor* x_e,
Tensor* y_n, Tensor* y_s, Tensor* d_w,
Tensor* d_e, Tensor* d_n, Tensor* d_s) {
auto& place = *ctx.eigen_device();
const int n = grid.dims()[0];
const int h = grid.dims()[1];
const int w = grid.dims()[2];
const T x_max = static_cast<T>(w - 1);
const T y_max = static_cast<T>(h - 1);
// split the grid of shape (n, h, w, 2) into (x, y) along the last dimension
Tensor grid_x, grid_y;
T* grid_x_data = grid_x.mutable_data<T>({n, h, w}, ctx.GetPlace());
T* grid_y_data = grid_y.mutable_data<T>({n, h, w}, ctx.GetPlace());
const T* grid_data = grid.data<T>();
for (int i = 0; i < n * h * w; i++) {
grid_x_data[i] = grid_data[2 * i];
grid_y_data[i] = grid_data[(2 * i) + 1];
}
Tensor ones;
ones.mutable_data<T>({n, h, w}, ctx.GetPlace());
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant(1.0);
// scale grid coordinates from [-1, 1] to [0, w-1] (x) and [0, h-1] (y)
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
grid_x_t.device(place) = 0.5 * ((grid_x_t + ones_t) * x_max);
grid_y_t.device(place) = 0.5 * ((grid_y_t + ones_t) * y_max);
// calculate coords of 4 corner points
x_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
x_e->mutable_data<T>({n, h, w}, ctx.GetPlace());
y_n->mutable_data<T>({n, h, w}, ctx.GetPlace());
y_s->mutable_data<T>({n, h, w}, ctx.GetPlace());
auto x_w_t = EigenTensor<T, 3>::From(*x_w);
auto x_e_t = EigenTensor<T, 3>::From(*x_e);
auto y_n_t = EigenTensor<T, 3>::From(*y_n);
auto y_s_t = EigenTensor<T, 3>::From(*y_s);
x_w_t.device(place) = grid_x_t.floor();
x_e_t.device(place) = x_w_t + ones_t;
y_n_t.device(place) = grid_y_t.floor();
y_s_t.device(place) = y_n_t + ones_t;
// calculate distances to 4 sides
d_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
d_e->mutable_data<T>({n, h, w}, ctx.GetPlace());
d_n->mutable_data<T>({n, h, w}, ctx.GetPlace());
d_s->mutable_data<T>({n, h, w}, ctx.GetPlace());
auto d_w_t = EigenTensor<T, 3>::From(*d_w);
auto d_e_t = EigenTensor<T, 3>::From(*d_e);
auto d_n_t = EigenTensor<T, 3>::From(*d_n);
auto d_s_t = EigenTensor<T, 3>::From(*d_s);
d_w_t.device(place) = grid_x_t - x_w_t;
d_e_t.device(place) = x_e_t - grid_x_t;
d_n_t.device(place) = grid_y_t - y_n_t;
d_s_t.device(place) = y_s_t - grid_y_t;
}
template <typename T>
static void GetGridPointValue(const Tensor& input, Tensor* output,
const Tensor& x, const Tensor& y) {
const int n = input.dims()[0];
const int c = input.dims()[1];
const int h = input.dims()[2];
const int w = input.dims()[3];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto output_t = EigenTensor<T, 4>::From(*output).setConstant((T)0);
auto input_t = EigenTensor<T, 4>::From(input);
for (int i = 0; i < n; i++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) {
for (int j = 0; j < c; j++) {
output_t(i, j, k, l) =
input_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l))));
}
}
}
}
}
}
template <typename T>
static void GatherOutputGradToInputGrad(const Tensor& output_grad,
Tensor* input_grad, const Tensor& x,
const Tensor& y, const Tensor& d1,
const Tensor& d2) {
const int n = output_grad.dims()[0];
const int c = output_grad.dims()[1];
const int h = output_grad.dims()[2];
const int w = output_grad.dims()[3];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto d1_t = EigenTensor<T, 3>::From(d1);
auto d2_t = EigenTensor<T, 3>::From(d2);
auto input_grad_t = EigenTensor<T, 4>::From(*input_grad);
auto output_grad_t = EigenTensor<T, 4>::From(output_grad);
for (int i = 0; i < n; i++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
if (isInBound(x_t(i, k, l), y_t(i, k, l), (T)(w - 1), (T)(h - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i, j, static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l)))) +=
output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l);
}
}
}
}
}
}
template <typename DeviceContext, typename T>
class GridSampleOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
const int n = input->dims()[0];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
// calc locations and distances of 4 corner points
Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s;
CalcGridLocations<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *grid, &x_w,
&x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s);
auto* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), output,
static_cast<T>(0));
// compute the values at the 4 corner points
Tensor v_wn, v_en, v_ws, v_es;
v_wn.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_en.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_ws.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_es.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
GetGridPointValue<T>(*input, &v_wn, x_w, y_n);
GetGridPointValue<T>(*input, &v_en, x_e, y_n);
GetGridPointValue<T>(*input, &v_ws, x_w, y_s);
GetGridPointValue<T>(*input, &v_es, x_e, y_s);
auto d_w_t = EigenTensor<T, 3>::From(d_w);
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto d_w_scaled_t =
d_w_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_e_scaled_t =
d_e_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_n_scaled_t =
d_n_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto d_s_scaled_t =
d_s_t.reshape(Array4(n, 1, h, w)).broadcast(Array4(1, c, 1, 1));
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto output_t = EigenTensor<T, 4>::From(*output);
// bilinear interpolation from the 4 corner points
output_t.device(place) = v_wn_t * d_e_scaled_t * d_s_scaled_t +
v_en_t * d_w_scaled_t * d_s_scaled_t +
v_ws_t * d_e_scaled_t * d_n_scaled_t +
v_es_t * d_w_scaled_t * d_n_scaled_t;
}
};
template <typename DeviceContext, typename T>
class GridSampleGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
const int n = input->dims()[0];
const int c = input->dims()[1];
const int h = input->dims()[2];
const int w = input->dims()[3];
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
input_grad->mutable_data<T>({n, c, h, w}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), input_grad,
static_cast<T>(0));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>({n, h, w, 2}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), grid_grad,
static_cast<T>(0));
Tensor x_w, x_e, y_n, y_s;
Tensor d_w, d_e, d_n, d_s;
CalcGridLocations<T>(
ctx.template device_context<platform::CPUDeviceContext>(), *grid, &x_w,
&x_e, &y_n, &y_s, &d_w, &d_e, &d_n, &d_s);
// accumulate the output gradient into the input gradient at the corner coords, weighted by the corner distances
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_n, d_e,
d_s);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_w, y_s, d_e,
d_n);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_n, d_w,
d_s);
GatherOutputGradToInputGrad<T>(*output_grad, input_grad, x_e, y_s, d_w,
d_n);
// compute the values at the 4 corner points
Tensor v_wn, v_en, v_ws, v_es;
v_wn.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_en.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_ws.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
v_es.mutable_data<T>({n, c, h, w}, ctx.GetPlace());
GetGridPointValue<T>(*input, &v_wn, x_w, y_n);
GetGridPointValue<T>(*input, &v_en, x_e, y_n);
GetGridPointValue<T>(*input, &v_ws, x_w, y_s);
GetGridPointValue<T>(*input, &v_es, x_e, y_s);
auto v_wn_t = EigenTensor<T, 4>::From(v_wn);
auto v_en_t = EigenTensor<T, 4>::From(v_en);
auto v_ws_t = EigenTensor<T, 4>::From(v_ws);
auto v_es_t = EigenTensor<T, 4>::From(v_es);
auto d_w_t = EigenTensor<T, 3>::From(d_w);
auto d_e_t = EigenTensor<T, 3>::From(d_e);
auto d_n_t = EigenTensor<T, 3>::From(d_n);
auto d_s_t = EigenTensor<T, 3>::From(d_s);
auto output_grad_t = EigenTensor<T, 4>::From(*output_grad);
Tensor grid_grad_x, grid_grad_y;
grid_grad_x.mutable_data<T>({n, h, w}, ctx.GetPlace());
grid_grad_y.mutable_data<T>({n, h, w}, ctx.GetPlace());
auto grid_grad_x_t = EigenTensor<T, 3>::From(grid_grad_x).setConstant(0.0);
auto grid_grad_y_t = EigenTensor<T, 3>::From(grid_grad_y).setConstant(0.0);
for (int i = 0; i < n; i++) {
for (int j = 0; j < c; j++) {
for (int k = 0; k < h; k++) {
for (int l = 0; l < w; l++) {
grid_grad_x_t(i, k, l) +=
((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) +
(v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) *
output_grad_t(i, j, k, l);
grid_grad_y_t(i, k, l) +=
((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) +
(v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) *
output_grad_t(i, j, k, l);
}
}
}
}
const T x_max = static_cast<T>(w - 1);
const T y_max = static_cast<T>(h - 1);
grid_grad_x_t = grid_grad_x_t * (x_max / (T)2);
grid_grad_y_t = grid_grad_y_t * (y_max / (T)2);
// interleave grid_grad (x, y) along the last dimension
T* grid_grad_data = grid_grad->data<T>();
T* grid_grad_x_data = grid_grad_x.data<T>();
T* grid_grad_y_data = grid_grad_y.data<T>();
for (int i = 0; i < n * h * w; i++) {
grid_grad_data[2 * i] = grid_grad_x_data[i];
grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
}
}
};
} // namespace operators
} // namespace paddle
......@@ -31,7 +31,7 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const framework::Tensor& input, const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_process,
framework::Tensor* output) {
bool exclusive, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -68,7 +68,8 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
pool_process.compute(input_data[h * input_width + w], &ele);
}
}
int pool_size = (hend - hstart) * (wend - wstart);
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[ph * output_width + pw] = ele;
}
......@@ -93,7 +94,7 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const framework::Tensor& output, const framework::Tensor& output_grad,
const std::vector<int>& ksize, const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_grad_process,
framework::Tensor* input_grad) {
bool exclusive, framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
......@@ -124,7 +125,8 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int pool_size = (hend - hstart) * (wend - wstart);
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
float scale = 1.0 / pool_size;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
......@@ -249,7 +251,7 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const framework::Tensor& input, const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_process,
framework::Tensor* output) {
bool exclusive, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
......@@ -300,7 +302,9 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
}
}
int pool_size =
(dend - dstart) * (hend - hstart) * (wend - wstart);
exclusive
? (dend - dstart) * (hend - hstart) * (wend - wstart)
: ksize_depth * ksize_height * ksize_width;
pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[output_idx] = ele;
}
......@@ -326,7 +330,7 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const framework::Tensor& output, const framework::Tensor& output_grad,
const std::vector<int>& ksize, const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_grad_process,
framework::Tensor* input_grad) {
bool exclusive, framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
......@@ -369,7 +373,9 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
wstart = std::max(wstart, 0);
int pool_size =
(dend - dstart) * (hend - hstart) * (wend - wstart);
exclusive
? (dend - dstart) * (hend - hstart) * (wend - wstart)
: ksize_depth * ksize_height * ksize_width;
float scale = 1.0 / pool_size;
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
......
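The new exclusive flag only changes the averaging divisor at the borders: exclusive mode divides by the number of in-bounds elements of the window, while inclusive mode always divides by the full kernel size. A standalone illustration of the divisor choice (not part of this patch):

#include <cstdio>

// Hedged illustration: the divisor chosen for one pooling window, mirroring
// the pool_size expressions above. hstart/wstart are already clipped to >= 0
// and hend/wend to the input extent.
int PoolDivisor(int hstart, int hend, int wstart, int wend,
                int ksize_height, int ksize_width, bool exclusive) {
  return exclusive ? (hend - hstart) * (wend - wstart)
                   : ksize_height * ksize_width;
}

int main() {
  // A 3x3 average-pooling window centered on a corner pixel (padding 1):
  // only a 2x2 region of the window is in bounds.
  std::printf("exclusive divisor: %d\n", PoolDivisor(0, 2, 0, 2, 3, 3, true));   // 4
  std::printf("inclusive divisor: %d\n", PoolDivisor(0, 2, 0, 2, 3, 3, false));  // 9
  return 0;
}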
......@@ -29,7 +29,7 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
const int ksize_width, const int stride_height,
const int stride_width, const int padding_height,
const int padding_width, PoolProcess pool_process,
T* output_data) {
bool exclusive, T* output_data) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
......@@ -52,7 +52,8 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
pool_process.compute(input_data[h * input_width + w], &ele);
}
}
int pool_size = (hend - hstart) * (wend - wstart);
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[index] = ele;
}
......@@ -65,7 +66,7 @@ __global__ void KernelPool2DGrad(
const int input_width, const int output_height, const int output_width,
const int ksize_height, const int ksize_width, const int stride_height,
const int stride_width, const int padding_height, const int padding_width,
PoolProcess pool_process, T* input_grad) {
PoolProcess pool_process, bool exclusive, T* input_grad) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int offsetW = index % input_width + padding_width;
......@@ -95,7 +96,8 @@ __global__ void KernelPool2DGrad(
int wend = min(wstart + ksize_width, input_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
int pool_size = (hend - hstart) * (wend - wstart);
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
int output_sub_idx = ph * output_width + pw;
pool_process.compute(input, output_data[output_sub_idx],
output_grad[output_sub_idx],
......@@ -163,7 +165,7 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const framework::Tensor& input, const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_process,
framework::Tensor* output) {
bool exclusive, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
......@@ -189,7 +191,8 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, input_channels, input_height, input_width,
output_height, output_width, ksize_height, ksize_width, stride_height,
stride_width, padding_height, padding_width, pool_process, output_data);
stride_width, padding_height, padding_width, pool_process, exclusive,
output_data);
}
};
......@@ -208,7 +211,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_process,
framework::Tensor* input_grad) {
bool exclusive, framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_height = input.dims()[2];
......@@ -236,7 +239,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
nthreads, input_data, output_data, output_grad_data, input_channels,
input_height, input_width, output_height, output_width, ksize_height,
ksize_width, stride_height, stride_width, padding_height, padding_width,
pool_process, input_grad_data);
pool_process, exclusive, input_grad_data);
}
};
......@@ -313,16 +316,14 @@ template class Pool2dGradFunctor<platform::CUDADeviceContext,
double>;
template <typename PoolProcess, typename T>
__global__ void KernelPool3D(const int nthreads, const T* input_data,
const int channels, const int input_depth,
const int input_height, const int input_width,
const int output_depth, const int output_height,
const int output_width, const int ksize_depth,
const int ksize_height, const int ksize_width,
const int stride_depth, const int stride_height,
const int stride_width, const int padding_depth,
const int padding_height, const int padding_width,
PoolProcess pool_process, T* output_data) {
__global__ void KernelPool3D(
const int nthreads, const T* input_data, const int channels,
const int input_depth, const int input_height, const int input_width,
const int output_depth, const int output_height, const int output_width,
const int ksize_depth, const int ksize_height, const int ksize_width,
const int stride_depth, const int stride_height, const int stride_width,
const int padding_depth, const int padding_height, const int padding_width,
PoolProcess pool_process, bool exclusive, T* output_data) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int pw = index % output_width;
......@@ -351,7 +352,9 @@ __global__ void KernelPool3D(const int nthreads, const T* input_data,
}
}
}
int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
int pool_size = exclusive
? (dend - dstart) * (hend - hstart) * (wend - wstart)
: ksize_depth * ksize_height * ksize_width;
pool_process.finalize(static_cast<T>(pool_size), &ele);
output_data[index] = ele;
}
......@@ -366,7 +369,7 @@ __global__ void KernelPool3DGrad(
const int ksize_height, const int ksize_width, const int stride_depth,
const int stride_height, const int stride_width, const int padding_depth,
const int padding_height, const int padding_width, PoolProcess pool_process,
T* input_grad) {
bool exclusive, T* input_grad) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int offsetW = index % input_width + padding_width;
......@@ -409,7 +412,9 @@ __global__ void KernelPool3DGrad(
dstart = max(dstart, 0);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
int pool_size =
exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
: ksize_depth * ksize_height * ksize_width;
int output_sub_idx = (pd * output_height + ph) * output_width + pw;
pool_process.compute(input, output_data[output_sub_idx],
output_grad[output_sub_idx],
......@@ -484,7 +489,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const framework::Tensor& input, const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_process,
framework::Tensor* output) {
bool exclusive, framework::Tensor* output) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_depth = input.dims()[2];
......@@ -517,7 +522,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
nthreads, input_data, input_channels, input_depth, input_height,
input_width, output_depth, output_height, output_width, ksize_depth,
ksize_height, ksize_width, stride_depth, stride_height, stride_width,
padding_depth, padding_height, padding_width, pool_process,
padding_depth, padding_height, padding_width, pool_process, exclusive,
output_data);
}
};
......@@ -537,7 +542,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_process,
framework::Tensor* input_grad) {
bool exclusive, framework::Tensor* input_grad) {
const int batch_size = input.dims()[0];
const int input_channels = input.dims()[1];
const int input_depth = input.dims()[2];
......@@ -573,7 +578,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
input_depth, input_height, input_width, output_depth, output_height,
output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
stride_height, stride_width, padding_depth, padding_height,
padding_width, pool_process, input_grad_data);
padding_width, pool_process, exclusive, input_grad_data);
}
};
......
......@@ -89,7 +89,7 @@ class Pool2dFunctor {
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
framework::Tensor* output);
bool exclusive, framework::Tensor* output);
};
template <typename DeviceContext, typename PoolProcess, typename T>
......@@ -101,7 +101,7 @@ class Pool2dGradFunctor {
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
framework::Tensor* input_grad);
bool exclusive, framework::Tensor* input_grad);
};
template <typename DeviceContext, class T>
......@@ -123,7 +123,7 @@ class Pool3dFunctor {
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
framework::Tensor* output);
bool exclusive, framework::Tensor* output);
};
template <typename DeviceContext, typename PoolProcess, typename T>
......@@ -135,7 +135,7 @@ class Pool3dGradFunctor {
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, PoolProcess pool_compute,
framework::Tensor* input_grad);
bool exclusive, framework::Tensor* input_grad);
};
template <typename DeviceContext, class T>
......
......@@ -41,6 +41,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
T *output_data = output->mutable_data<T>(ctx.GetPlace());
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
bool exclusive = ctx.Attr<bool>("exclusive");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
......@@ -72,7 +73,8 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
if (pooling_type == "max") {
pooling_mode = PoolingMode::kMaximum;
} else {
pooling_mode = PoolingMode::kAverage;
pooling_mode = exclusive ? PoolingMode::kAverageExclusive
: PoolingMode::kAverageInclusive;
}
cudnnPoolingDescriptor_t cudnn_pool_desc =
......@@ -101,6 +103,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
Tensor *input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
bool exclusive = ctx.Attr<bool>("exclusive");
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
......@@ -141,7 +144,8 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
pooling_mode = PoolingMode::kMaximum;
}
} else {
pooling_mode = PoolingMode::kAverage;
pooling_mode = exclusive ? PoolingMode::kAverageExclusive
: PoolingMode::kAverageInclusive;
}
cudnnPoolingDescriptor_t cudnn_pool_desc =
......
......@@ -180,6 +180,12 @@ void Pool2dOpMaker::Make() {
"operator."
"If global_pooling = true, paddings and ksize will be ignored.")
.SetDefault({0, 0});
AddAttr<bool>(
"exclusive",
"(bool, default True) When true, will exclude the zero-padding in the "
"averaging calculating, otherwise, include the zero-padding. Note, it "
"is only used when pooling_type is avg. The defalut is True.")
.SetDefault(true);
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
......@@ -236,6 +242,23 @@ Example:
W_{out} = \\frac{(W_{in} - ksize[1] + 2 * paddings[1] + strides[1] - 1)}{strides[1]} + 1
$$
For exclusive = true:
$$
hstart = i * strides[0] - paddings[0]
hend = hstart + ksize[0]
wstart = j * strides[1] - paddings[1]
wend = wstart + ksize[1]
Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{ksize[0] * ksize[1]}
$$
For exclusive = false:
$$
hstart = max(0, i * strides[0] - paddings[0])
hend = min(H, hstart + ksize[0])
wstart = max(0, j * strides[1] - paddings[1])
wend = min(W, wstart + ksize[1])
Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)}
$$
)DOC");
}
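To make the exclusive/inclusive formulas above concrete, here is a minimal NumPy sketch of average pooling over a padded input (a reference only; it mirrors the avg_pool2D_forward_naive helper in the tests further below, and the shapes and names are illustrative):

import numpy as np

x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
ksize, stride, pad = 3, 1, 1
# with ksize=3, stride=1, pad=1 the output spatial size equals the input's

def avg_pool2d_ref(x, exclusive=True):
    n, c, h, w = x.shape
    out = np.zeros_like(x)
    for i in range(h):
        for j in range(w):
            hs, ws = max(i * stride - pad, 0), max(j * stride - pad, 0)
            he = min(i * stride - pad + ksize, h)
            we = min(j * stride - pad + ksize, w)
            window = x[:, :, hs:he, ws:we]
            # exclusive divides by the in-bounds window area,
            # inclusive always divides by ksize * ksize
            area = (he - hs) * (we - ws) if exclusive else ksize * ksize
            out[:, :, i, j] = window.sum(axis=(2, 3)) / area
    return out

print(avg_pool2d_ref(x, exclusive=True)[0, 0, 0, 0])   # 10/4 = 2.5
print(avg_pool2d_ref(x, exclusive=False)[0, 0, 0, 0])  # 10/9 ~= 1.11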
......@@ -283,6 +306,12 @@ void Pool3dOpMaker::Make() {
"If global_pooling = true, ksize and paddings will be ignored.")
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<bool>(
"exclusive",
"(bool, default True) When true, will exclude the zero-padding in the "
"averaging calculating, otherwise, include the zero-padding. Note, it "
"is only used when pooling_type is avg. The defalut is True.")
.SetDefault(true);
AddAttr<bool>(
"use_cudnn",
......
......@@ -69,6 +69,7 @@ class PoolKernel : public framework::OpKernel<T> {
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
bool exclusive = context.Attr<bool>("exclusive");
if (context.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
......@@ -84,7 +85,7 @@ class PoolKernel : public framework::OpKernel<T> {
pool2d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
out);
true, out);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor<
......@@ -92,7 +93,7 @@ class PoolKernel : public framework::OpKernel<T> {
pool2d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
out);
exclusive, out);
}
} break;
case 3: {
......@@ -102,14 +103,14 @@ class PoolKernel : public framework::OpKernel<T> {
pool3d_forward;
paddle::operators::math::MaxPool<T> pool_process;
pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
out);
true, out);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dFunctor<
DeviceContext, paddle::operators::math::AvgPool<T>, T>
pool3d_forward;
paddle::operators::math::AvgPool<T> pool_process;
pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
out);
exclusive, out);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
......@@ -131,6 +132,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
bool exclusive = context.Attr<bool>("exclusive");
if (context.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) {
......@@ -157,7 +159,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
pool2d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process;
pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
paddings, pool_process, in_x_grad);
paddings, pool_process, exclusive, in_x_grad);
}
} break;
case 3: {
......@@ -172,7 +174,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
pool3d_backward;
paddle::operators::math::AvgPoolGrad<T> pool_process;
pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
paddings, pool_process, in_x_grad);
paddings, pool_process, exclusive, in_x_grad);
}
} break;
default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
......
......@@ -44,6 +44,12 @@ class SoftmaxWithCrossEntropyOpMaker
"(bool, default: false), A flag to indicate whether to interpretate "
"the given labels as soft labels.")
.SetDefault(false);
AddAttr<bool>(
"numeric_stable_mode",
"(bool, default: false), A flag to indicate whether to use more "
"numerically stable algorithm. This flag is only valid when "
"soft_label is false and GPU is used.")
.SetDefault(false);
AddAttr<int>(
"ignore_index",
"(int, default -100), Specifies a target value that is ignored and"
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <cub/cub.cuh>
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
......@@ -117,8 +118,8 @@ using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
// Make sure that BlockDim <= feature_size
// This kernel is used to calculate the max element of each row
template <typename T, int BlockDim>
__global__ void RowReductionForMax(const T* logits_data, T* max_data,
int feature_size) {
static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
int feature_size) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
......@@ -141,9 +142,10 @@ __global__ void RowReductionForMax(const T* logits_data, T* max_data,
}
// Make sure that BlockDim <= feature_size
template <typename T, int BlockDim>
__global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data,
T* softmax, int feature_size) {
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
T* max_data, T* softmax,
int feature_size) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
......@@ -153,24 +155,34 @@ __global__ void RowReductionForDiffMaxSum(const T* logits_data, T* max_data,
softmax[beg_idx] = logits_data[beg_idx] - block_max;
T diff_max_sum = real_exp(softmax[beg_idx]);
beg_idx += BlockDim;
while (beg_idx < end_idx) {
softmax[beg_idx] = logits_data[beg_idx] - block_max;
diff_max_sum += real_exp(softmax[beg_idx]);
beg_idx += BlockDim;
auto idx = beg_idx + BlockDim;
while (idx < end_idx) {
softmax[idx] = logits_data[idx] - block_max;
diff_max_sum += real_exp(softmax[idx]);
idx += BlockDim;
}
diff_max_sum =
BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
if (threadIdx.x == 0) max_data[blockIdx.x] = real_log(diff_max_sum);
if (!CalculateLogSoftmax) return;
__syncthreads();
diff_max_sum = max_data[blockIdx.x];
softmax[beg_idx] -= diff_max_sum;
beg_idx += BlockDim;
while (beg_idx < end_idx) {
softmax[beg_idx] -= diff_max_sum;
beg_idx += BlockDim;
}
if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
}
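Per row, the two reduction kernels above amount to a standard log-sum-exp pass. A rough NumPy sketch of what they compute for one row (the CUDA block reduction and striding are elided; values are illustrative):

import numpy as np

logits = np.array([2.0, 1.0, 0.1], dtype=np.float32)  # one row

row_max = logits.max()                    # RowReductionForMax
shifted = logits - row_max                # softmax[i] = logits[i] - block_max
log_sum = np.log(np.exp(shifted).sum())   # max_data[row] = log(sum(exp(diff)))
log_softmax = shifted - log_sum           # extra pass when CalculateLogSoftmax
print(np.exp(log_softmax).sum())          # ~1.0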
// Make sure that BlockDim <= feature_size
template <typename T, int BlockDim>
__global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data,
const T* labels_data,
T* loss_data, T* softmax,
int feature_size) {
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
const T* logits_data, const T* labels_data, T* loss_data, T* softmax,
int feature_size) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
auto beg_idx = feature_size * blockIdx.x + threadIdx.x;
......@@ -194,11 +206,134 @@ __global__ void RowReductionForSoftmaxAndCrossEntropy(const T* logits_data,
}
template <typename T>
__global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out, int batch_size) {
struct HardLabelSoftmaxWithCrossEntropyFunctor {
public:
HardLabelSoftmaxWithCrossEntropyFunctor(const T* logits,
const int64_t* labels, T* loss,
T* log_softmax, int feature_size)
: logits_(logits),
labels_(labels),
loss_(loss),
log_softmax_(log_softmax),
feature_size_(feature_size) {}
__device__ void operator()(int idx) const {
auto row_idx = idx / feature_size_;
auto col_idx = idx % feature_size_;
if (col_idx != labels_[row_idx]) {
log_softmax_[idx] = real_exp(log_softmax_[idx]);
} else {
auto softmax = log_softmax_[idx];
log_softmax_[idx] = real_exp(softmax);
loss_[row_idx] = -softmax;
}
}
private:
const T* logits_;
const int64_t* labels_;
T* loss_;
T* log_softmax_;
int feature_size_;
};
template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
public:
HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const T* logits,
const int64_t* labels,
T* loss, T* log_softmax,
int feature_size,
int ignore_idx)
: logits_(logits),
labels_(labels),
loss_(loss),
log_softmax_(log_softmax),
feature_size_(feature_size),
ignore_idx_(ignore_idx) {}
__device__ void operator()(int idx) const {
auto row_idx = idx / feature_size_;
auto col_idx = idx % feature_size_;
if (col_idx != labels_[row_idx] || col_idx == ignore_idx_) {
log_softmax_[idx] = real_exp(log_softmax_[idx]);
} else {
auto softmax = log_softmax_[idx];
log_softmax_[idx] = real_exp(softmax);
loss_[row_idx] = -softmax;
}
}
private:
const T* logits_;
const int64_t* labels_;
T* loss_;
T* log_softmax_;
int feature_size_;
int ignore_idx_;
};
template <typename T>
static __global__ void SetSoftmaxToOneWhenFeatureSizeIsOne(T* out,
int batch_size) {
auto idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < batch_size) out[idx] = static_cast<T>(1);
}
template <typename T>
static void HardLabelSoftmaxWithCrossEntropy(
const platform::CUDADeviceContext& ctx, const T* logits_data,
const int64_t* labels_data, T* loss_data, T* softmax_data, int batch_size,
int feature_size, int ignore_idx) {
constexpr int kMaxBlockDim = 512;
int block_dim = feature_size >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(feature_size)));
auto stream = ctx.stream();
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: { \
RowReductionForMax<T, BlockDim><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, feature_size); \
RowReductionForDiffMaxSum<T, BlockDim, \
true><<<batch_size, BlockDim, 0, stream>>>( \
logits_data, loss_data, softmax_data, feature_size); \
platform::ForRange<platform::CUDADeviceContext> for_range( \
ctx, batch_size* feature_size); \
if (ignore_idx >= 0 && ignore_idx < feature_size) { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx<T>( \
logits_data, labels_data, loss_data, softmax_data, feature_size, \
ignore_idx)); \
} else { \
for_range(HardLabelSoftmaxWithCrossEntropyFunctor<T>( \
logits_data, labels_data, loss_data, softmax_data, feature_size)); \
} \
} break
switch (block_dim) {
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
case 1:
SetSoftmaxToOneWhenFeatureSizeIsOne<<<(batch_size + kMaxBlockDim - 1) /
kMaxBlockDim,
kMaxBlockDim, 0, stream>>>(
softmax_data, batch_size);
cudaMemsetAsync(loss_data, 0, batch_size * sizeof(T), stream);
break;
default:
PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
break;
}
#undef CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}
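Putting the pieces together, a NumPy reference for what the stable hard-label path computes end-to-end (a simplified sketch: the real kernels reuse loss_data as scratch for the row max and log-sum, and run as fused CUDA blocks):

import numpy as np

def hard_label_softmax_xent_ref(logits, labels, ignore_idx=-100):
    row_max = logits.max(axis=1, keepdims=True)
    log_softmax = logits - row_max - np.log(
        np.exp(logits - row_max).sum(axis=1, keepdims=True))
    loss = np.zeros(logits.shape[0], dtype=logits.dtype)
    for r, lbl in enumerate(labels):
        if lbl != ignore_idx:                 # mirrors the WithIgnoreIdx functor
            loss[r] = -log_softmax[r, lbl]
    return np.exp(log_softmax), loss          # (softmax, per-row loss)

softmax, loss = hard_label_softmax_xent_ref(
    np.random.rand(4, 10).astype('float32'),
    np.array([1, 3, 5, 7], dtype='int64'))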
template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
const T* labels_data,
......@@ -237,7 +372,7 @@ static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
kMaxBlockDim,
kMaxBlockDim, 0, stream>>>(
softmax_data, batch_size);
cudaMemsetAsync(loss_data, 0, batch_size, stream);
cudaMemsetAsync(loss_data, 0, batch_size * sizeof(T), stream);
break;
default:
PADDLE_THROW("BlockDim must be 2^n in softmax_with_cross_entropy_op");
......@@ -272,11 +407,21 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
logits_data, labels_data, softmax_data, loss_data, batch_size,
feature_size, context.cuda_device_context().stream());
} else {
math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
softmax);
math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
context.cuda_device_context(), loss, softmax, labels, false,
ignore_index);
if (!context.Attr<bool>("numeric_stable_mode")) {
math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
softmax);
math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
context.cuda_device_context(), loss, softmax, labels, false,
ignore_index);
} else {
int batch_size = logits->dims()[0];
int feature_size = logits->dims()[1];
auto* logits_data = logits->data<T>();
auto* labels_data = labels->data<int64_t>();
HardLabelSoftmaxWithCrossEntropy<T>(
context.cuda_device_context(), logits_data, labels_data, loss_data,
softmax_data, batch_size, feature_size, ignore_index);
}
}
}
};
......
......@@ -56,12 +56,14 @@ class SppKernel : public framework::OpKernel<T> {
math::Pool2dFunctor<DeviceContext, math::MaxPool<T>, T> pool_forward;
math::MaxPool<T> max_process;
pool_forward(context.template device_context<DeviceContext>(), *in_x,
kernel_size, strides, paddings, max_process, &out_level);
kernel_size, strides, paddings, max_process, true,
&out_level);
} else if (pooling_type == "avg") {
math::Pool2dFunctor<DeviceContext, math::AvgPool<T>, T> pool_forward;
math::AvgPool<T> avg_process;
pool_forward(context.template device_context<DeviceContext>(), *in_x,
kernel_size, strides, paddings, avg_process, &out_level);
kernel_size, strides, paddings, avg_process, true,
&out_level);
}
// flatten pooling output shape
int output_flatten_w = in_x->dims()[1] * bins * bins;
......@@ -154,7 +156,7 @@ class SppGradKernel : public framework::OpKernel<T> {
math::AvgPoolGrad<T> avg_process;
pool_backward(context.template device_context<DeviceContext>(), *in_x,
*&out_level, *&outgrad_level, kernel_size, strides,
paddings, avg_process, in_x_grad);
paddings, avg_process, true, in_x_grad);
}
}
}
......
......@@ -76,8 +76,9 @@ enum class DataLayout { // Not use
enum class PoolingMode {
kMaximum,
kAverage,
kMaximumDeterministic,
kAverageExclusive,
kAverageInclusive,
};
#if CUDNN_VERSION < 6000
......@@ -91,8 +92,10 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
switch (mode) {
case PoolingMode::kMaximumDeterministic:
return CUDNN_POOLING_MAX;
case PoolingMode::kAverage:
case PoolingMode::kAverageExclusive:
return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
case PoolingMode::kAverageInclusive:
return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
case PoolingMode::kMaximum:
return CUDNN_POOLING_MAX;
default:
......@@ -105,8 +108,10 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
switch (mode) {
case PoolingMode::kMaximumDeterministic:
return CUDNN_POOLING_MAX_DETERMINISTIC;
case PoolingMode::kAverage:
case PoolingMode::kAverageExclusive:
return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
case PoolingMode::kAverageInclusive:
return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
case PoolingMode::kMaximum:
return CUDNN_POOLING_MAX;
default:
......
......@@ -821,6 +821,13 @@ All parameter, weight, gradient are variables in Paddle.
[](BuildStrategy &self, bool b) {
self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important
.def_property("enable_sequential_execution",
[](const BuildStrategy &self) {
return self.enable_sequential_execution_;
},
[](BuildStrategy &self, bool b) {
self.enable_sequential_execution_ = b;
})
.def_property(
"fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
......
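From the Python side, the newly exposed property is set on a BuildStrategy before constructing a ParallelExecutor. A minimal usage sketch (the tiny program here is illustrative; the parallel-executor tests below exercise the same flag):

import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[4], dtype='float32')
y = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(y)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

build_strategy = fluid.BuildStrategy()
build_strategy.enable_sequential_execution = True  # run ops in program order
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
pe = fluid.ParallelExecutor(use_cuda=False,
                            loss_name=loss.name,
                            build_strategy=build_strategy)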
......@@ -1586,8 +1586,7 @@ class DynamicRNN(object):
self.lod_rank_table = None
self.max_seq_len = None
self.step_idx = None
self.zero_idx = fill_constant(
shape=[1], value=0, dtype='int64', force_cpu=True)
self.zero_idx = None
self.mem_dict = dict()
self.output_array = []
self.outputs = []
......@@ -1792,6 +1791,7 @@ class DynamicRNN(object):
"""
self._assert_in_rnn_block_('memory')
self._init_zero_idx_()
if init is not None:
if not isinstance(init, Variable):
raise TypeError(
......@@ -1905,6 +1905,22 @@ class DynamicRNN(object):
array_write(x=each, i=self.step_idx, array=outside_array)
self.output_array.append(outside_array)
def _init_zero_idx_(self):
if self.zero_idx is None:
parent_block = self._parent_block_()
self.zero_idx = parent_block.create_var(
name=unique_name.generate('zero_idx'), dtype='int64')
parent_block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [self.zero_idx]},
attrs={
'shape': [1],
'dtype': self.zero_idx.dtype,
'value': float(0),
'force_cpu': True
})
def _parent_block_(self):
prog = self.helper.main_program
parent_idx = prog.current_block().parent_idx
......
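For context, a minimal DynamicRNN usage that exercises memory() — and hence the lazy _init_zero_idx_ above (a sketch; shapes and names are illustrative):

import paddle.fluid as fluid

sentence = fluid.layers.data(
    name='sentence', shape=[32], dtype='float32', lod_level=1)
drnn = fluid.layers.DynamicRNN()
with drnn.block():
    word = drnn.step_input(sentence)
    prev = drnn.memory(shape=[32], value=0.0)   # triggers _init_zero_idx_()
    hidden = fluid.layers.fc(input=[word, prev], size=32, act='relu')
    drnn.update_memory(prev, hidden)
    drnn.output(hidden)
last = fluid.layers.sequence_last_step(drnn())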
......@@ -158,6 +158,7 @@ __all__ = [
'sequence_reverse',
'affine_channel',
'hash',
'grid_sampler',
'log_loss',
'add_position_encoding',
]
......@@ -2101,7 +2102,8 @@ def pool2d(input,
global_pooling=False,
use_cudnn=True,
ceil_mode=False,
name=None):
name=None,
exclusive=True):
"""
${comment}
......@@ -2115,11 +2117,13 @@ def pool2d(input,
pool_type: ${pooling_type_comment}
pool_stride (int): stride of the pooling layer.
pool_padding (int): padding size.
global_pooling: ${global_pooling_comment}
use_cudnn: ${use_cudnn_comment}
ceil_mode: ${ceil_mode_comment}
global_pooling (bool): ${global_pooling_comment}
use_cudnn (bool): ${use_cudnn_comment}
ceil_mode (bool): ${ceil_mode_comment}
name (str|None): A name for this layer (optional). If set None, the
layer will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling
mode. Default: True.
Returns:
Variable: The pooling result.
......@@ -2177,7 +2181,8 @@ def pool2d(input,
"paddings": pool_padding,
"use_cudnn": use_cudnn,
"ceil_mode": ceil_mode,
"use_mkldnn": False
"use_mkldnn": False,
"exclusive": exclusive,
})
return pool_out
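A usage sketch of the new argument (shapes are illustrative; exclusive only matters for avg pooling with non-zero padding):

import paddle.fluid as fluid

data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
# exclusive=False divides every window by pool_size**2, counting zero-padding
pool = fluid.layers.pool2d(input=data, pool_size=2, pool_type='avg',
                           pool_stride=1, pool_padding=1, exclusive=False)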
......@@ -2191,7 +2196,8 @@ def pool3d(input,
global_pooling=False,
use_cudnn=True,
ceil_mode=False,
name=None):
name=None,
exclusive=True):
"""
This function adds an operator for 3-D pooling, using the pooling
configuration specified by the input parameters.
......@@ -2207,6 +2213,8 @@ def pool3d(input,
ceil_mode (bool): ${ceil_mode_comment}
name (str): A name for this layer (optional). If set None, the layer
will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling
mode. Default: True.
Returns:
Variable: output of pool3d layer.
......@@ -2245,7 +2253,8 @@ def pool3d(input,
"paddings": pool_padding,
"use_cudnn": use_cudnn,
"ceil_mode": ceil_mode,
"use_mkldnn": False
"use_mkldnn": False,
"exclusive": exclusive,
})
return pool_out
......@@ -4713,7 +4722,8 @@ def multiplex(inputs, index):
def softmax_with_cross_entropy(logits,
label,
soft_label=False,
ignore_index=-100):
ignore_index=-100,
numeric_stable_mode=False):
"""
**Softmax With Cross Entropy Operator.**
......@@ -4747,6 +4757,18 @@ def softmax_with_cross_entropy(logits,
\\left(\\text{logit}_i - \\log\\left(\\sum_{i=0}^{K}
\\exp(\\text{logit}_i)\\right)\\right), j = 1,...,K
3) If numeric_stable_mode is True, softmax is calculated first by:
.. math::
max_j = \\max_{i=0}^{K}{\\text{logit}_i}
log\\_max\\_sum_j = \\log\\sum_{i=0}^{K}\\exp(logit_i - max_j)
softmax_j = \\exp(logit_j - max_j - {log\\_max\\_sum}_j)
and then cross entropy loss is calculated from the softmax output and the label.
Args:
logits (Variable): The unscaled log probabilities, which is a 2-D tensor
with shape [N x K]. N is the batch_size, and K is the class number.
......@@ -4758,6 +4780,13 @@ def softmax_with_cross_entropy(logits,
ignore_index (int): Specifies a target value that is ignored and does
not contribute to the input gradient. Only valid
if soft_label is set to False. Default: -100
numeric_stable_mode (bool): A flag to indicate whether to use a more
numerically stable algorithm. Only valid
when soft_label is False and GPU is used.
When soft_label is True or CPU is used,
the algorithm is always numerically stable.
Note that the speed may be slower when the
stable algorithm is used. Default: False
Returns:
Variable: The cross entropy loss is a 2-D tensor with shape [N x 1].
......@@ -4780,8 +4809,11 @@ def softmax_with_cross_entropy(logits,
'Label': label},
outputs={'Softmax': softmax,
'Loss': loss},
attrs={'soft_label': soft_label,
'ignore_index': ignore_index})
attrs={
'soft_label': soft_label,
'ignore_index': ignore_index,
'numeric_stable_mode': numeric_stable_mode
})
return loss
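A usage sketch of the new flag (shapes are illustrative; per the docstring, it only changes the GPU hard-label path):

import paddle.fluid as fluid

logits = fluid.layers.data(name='logits', shape=[37], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
loss = fluid.layers.softmax_with_cross_entropy(
    logits=logits, label=label, numeric_stable_mode=True)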
......@@ -7712,19 +7744,59 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
def hash(input, hash_size, num_hash=1, name=None):
"""
hash the input
Args:
input (Variable): The input variable which is a one-hot word.
hash_size (int): The space size for hash algorithm.
Hash the input to an integer whose value is less than the given hash size.
The hash algorithm used is xxHash, an extremely fast hash algorithm
(https://github.com/Cyan4973/xxHash/tree/v0.6.5).
A simple example as below:
.. code-block:: text
Given:
# shape [2, 2]
input.data = [
[[1], [2]],
[[3], [4]],
]
input.lod = [[0, 2]]
hash_size = 10000
num_hash = 4
Then:
The hash op takes all the numbers in the input's 2nd dimension as the hash
algorithm's input each time. Each input is hashed 4 times, yielding an
array of length 4. Each value in the array ranges from 0 to 9999.
# shape [2, 4]
output.data = [
[[9662], [9217], [1129], [8487]],
[[8310], [1327], [1654], [4567]],
]
output.lod = [[0, 2]]
Args:
input (Variable): The input variable which is a one-hot word. The
dimensions of the input variable must be 2.
hash_size (int): The space size for the hash algorithm. The output value
will be in the range :math:`[0, hash_size - 1]`.
num_hash (int): The times of hash, default 1.
name (str, default None): The name of this layer.
Returns:
Variable: The hash result variable which is a LoDTensor.
Examples:
.. code-block:: python
word_dict = paddle.dataset.imdb.word_dict()
x = fluid.layers.data(name='x', shape=[1], dtype='int32', lod_level=1)
out = fluid.layers.hash(input=x, hash_size=len(word_dict), num_hash=4)
"""
helper = LayerHelper('hash', **locals())
out = helper.create_variable_for_type_inference(
......@@ -7738,6 +7810,87 @@ def hash(input, hash_size, num_hash=1, name=None):
return out
@templatedoc()
def grid_sampler(x, grid, name=None):
"""
This operation samples input X by using bilinear interpolation based on
a flow field grid, which is usually generated by affine_grid. The grid of
shape [N, H, W, 2] is the concatenation of (grid_x, grid_y) coordinates
with shape [N, H, W] each, where grid_x indexes the 4th dimension
(the width dimension) of input data x and grid_y indexes the 3rd
dimension (the height dimension). The final result is the bilinear
interpolation value of the 4 nearest corner points.
Step 1:
Get (x, y) grid coordinates and scale them to [0, H-1] and [0, W-1].
grid_x = 0.5 * (grid[:, :, :, 0] + 1) * (W - 1)
grid_y = 0.5 * (grid[:, :, :, 1] + 1) * (H - 1)
Step 2:
Index input data X with grid (x, y) in each [H, W] area, and bilinearly
interpolate the point value from the 4 nearest points.
wn ------- y_n ------- en
| | |
| d_n |
| | |
x_w --d_w-- grid--d_e-- x_e
| | |
| d_s |
| | |
ws ------- y_s ------- es
x_w = floor(x) // west side x coord
x_e = x_w + 1 // east side x coord
y_n = floor(y) // north side y coord
y_s = y_n + 1 // south side y coord
d_w = grid_x - x_w // distance to west side
d_e = x_e - grid_x // distance to east side
d_n = grid_y - y_n // distance to north side
d_s = y_s - grid_y // distance to south side
wn = X[:, :, y_n, x_w] // north-west point value
en = X[:, :, y_n, x_e] // north-east point value
ws = X[:, :, y_s, x_w] // south-west point value
es = X[:, :, y_s, x_e] // south-east point value
output = wn * d_e * d_s + en * d_w * d_s
+ ws * d_e * d_n + es * d_w * d_n
Args:
x(Variable): Input data of shape [N, C, H, W].
grid(Variable): Input grid tensor of shape [N, H, W, 2].
name (str, default None): The name of this layer.
Returns:
out(Variable): Output of shape [N, C, H, W], the result of sampling input X
using bilinear interpolation based on the input grid.
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[3, 10, 32, 32], dtype='float32')
theta = fluid.layers.data(name='theta', shape=[3, 2, 3], dtype='float32')
grid = fluid.layers.affine_grid(input=theta, size=[3, 10, 32, 32])
out = fluid.layers.grid_sampler(x=x, grid=grid)
"""
helper = LayerHelper("grid_sampler", **locals())
if not isinstance(x, Variable):
raise ValueError("The x should be a Variable")
if not isinstance(grid, Variable):
raise ValueError("The grid should be a Variable")
out = helper.create_variable_for_type_inference(x.dtype)
ipts = {'X': x, 'Grid': grid}
helper.append_op(type='grid_sampler', inputs=ipts, outputs={'Output': out})
return out
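To sanity-check the Step 1/2 math in the docstring, a single-point NumPy sketch (values are illustrative; a full reference implementation appears in the grid_sampler test further below):

import numpy as np

X = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)  # [N, C, H, W]
gx, gy = 0.25, -0.5                      # one (grid_x, grid_y) entry in [-1, 1]
x = 0.5 * (gx + 1) * (4 - 1)             # scale to [0, W-1] -> 1.875
y = 0.5 * (gy + 1) * (4 - 1)             # scale to [0, H-1] -> 0.75
x_w, y_n = int(np.floor(x)), int(np.floor(y))
x_e, y_s = x_w + 1, y_n + 1
d_w, d_e, d_n, d_s = x - x_w, x_e - x, y - y_n, y_s - y
out = (X[0, 0, y_n, x_w] * d_e * d_s + X[0, 0, y_n, x_e] * d_w * d_s +
       X[0, 0, y_s, x_w] * d_e * d_n + X[0, 0, y_s, x_e] * d_w * d_n)
print(out)  # 4.875, the bilinear value at (x, y)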
def log_loss(input, label, epsilon=1e-4, name=None):
"""
**Negative Log Loss Layer**
......
......@@ -40,7 +40,8 @@ class TestParallelExecutorBase(unittest.TestCase):
use_reduce=False,
fuse_elewise_add_act_ops=False,
optimizer=fluid.optimizer.Adam,
use_fast_executor=False):
use_fast_executor=False,
enable_sequential_execution=False):
def run_executor(exe, feed, fetch_list, program=None):
if isinstance(exe, fluid.ParallelExecutor):
res = exe.run(fetch_list=fetch_list, feed=feed)
......@@ -80,6 +81,7 @@ class TestParallelExecutorBase(unittest.TestCase):
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
build_strategy.enable_sequential_execution = enable_sequential_execution
if use_parallel_executor:
exe = fluid.ParallelExecutor(
......
......@@ -72,6 +72,7 @@ class TestDistSaveLoadDense2x2(TestDistBase):
self.assertAlmostEqual(local_np.all(), train1_np.all(), delta=delta)
self.assertAlmostEqual(train0_np.all(), train1_np.all(), delta=delta)
@unittest.skip(reason="CI fail")
def test_dist(self):
need_envs = {
"IS_DISTRIBUTED": '0',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
def AffineGrid(theta, size):
n = size[0]
h = size[2]
w = size[3]
h_idx = np.repeat(
np.linspace(-1, 1, h)[np.newaxis, :], w, axis=0).T[:, :, np.newaxis]
w_idx = np.repeat(
np.linspace(-1, 1, w)[np.newaxis, :], h, axis=0)[:, :, np.newaxis]
grid = np.concatenate(
[w_idx, h_idx, np.ones([h, w, 1])], axis=2) # h * w * 3
grid = np.repeat(grid[np.newaxis, :], size[0], axis=0) # n * h * w *3
ret = np.zeros([n, h * w, 2])
theta = theta.transpose([0, 2, 1])
for i in range(len(theta)):
ret[i] = np.dot(grid[i].reshape([h * w, 3]), theta[i])
return ret.reshape([n, h, w, 2]).astype("float32")
def getGridPointValue(data, x, y):
data_shape = data.shape
N = data_shape[0]
H = data_shape[2]
W = data_shape[3]
out = np.zeros(data_shape, dtype='float')
for i in range(N):
for j in range(H):
for k in range(W):
if y[i, j, k] < 0 or y[i, j, k] > H - 1 or x[i, j, k] < 0 or x[
i, j, k] > W - 1:
out[i, :, j, k] = 0
else:
out[i, :, j, k] = data[i, :, y[i, j, k], x[i, j, k]]
return out
def GridSampler(data, grid):
dims = data.shape
N = dims[0]
C = dims[1]
H = dims[2]
W = dims[3]
x = grid[:, :, :, 0]
y = grid[:, :, :, 1]
y_max = H - 1
x_max = W - 1
x = 0.5 * ((x.astype('float32') + 1.0) * x_max)
y = 0.5 * ((y.astype('float32') + 1.0) * y_max)
x0 = np.floor(x).astype('int32')
x1 = x0 + 1
y0 = np.floor(y).astype('int32')
y1 = y0 + 1
wa = np.tile(((x1 - x) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1))
wb = np.tile(((x1 - x) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1))
wc = np.tile(((x - x0) * (y1 - y)).reshape((N, 1, H, W)), (1, C, 1, 1))
wd = np.tile(((x - x0) * (y - y0)).reshape((N, 1, H, W)), (1, C, 1, 1))
va = getGridPointValue(data, x0, y0)
vb = getGridPointValue(data, x0, y1)
vc = getGridPointValue(data, x1, y0)
vd = getGridPointValue(data, x1, y1)
out = (wa * va + wb * vb + wc * vc + wd * vd).astype('float32')
return out
class TestGridSamplerOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'grid_sampler'
x = np.random.randint(0, 255, self.x_shape).astype('float32')
theta = np.zeros(self.theta_shape).astype('float32')
for i in range(self.theta_shape[0]):
for j in range(2):
for k in range(3):
theta[i, j, k] = np.random.rand(1)[0]
grid = AffineGrid(theta, self.x_shape)
self.inputs = {'X': x, 'Grid': grid}
self.attrs = {'use_cudnn': True}
self.outputs = {'Output': GridSampler(x, grid)}
def test_check_output(self):
self.check_output(atol=1e-3)
def test_check_grad_normal(self):
self.check_grad(['X', 'Grid'], 'Output', max_relative_error=0.61)
def initTestCase(self):
self.x_shape = (2, 5, 7, 3)
self.grid_shape = (2, 7, 3, 2)
self.theta_shape = (2, 2, 3)
if __name__ == "__main__":
unittest.main()
......@@ -865,6 +865,15 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out)
print(str(program))
def test_grid_sampler(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[3, 5, 7], dtype='float32')
grid = layers.data(name='grid', shape=[5, 7, 2], dtype='float32')
out = layers.grid_sampler(x, grid)
self.assertIsNotNone(out)
print(str(program))
def test_affine_grid(self):
program = Program()
with program_guard(program):
......
......@@ -232,6 +232,46 @@ class TestResnet(TestParallelExecutorBase):
for loss in zip(all_reduce_last_loss, reduce_last_loss):
self.assertAlmostEquals(loss[0], loss[1], delta=delta2)
if not use_cuda:
return
all_reduce_first_loss_seq, all_reduce_last_loss_seq = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
iter=iter,
batch_size=batch_size,
use_cuda=use_cuda,
use_reduce=False,
optimizer=optimizer,
enable_sequential_execution=True)
reduce_first_loss_seq, reduce_last_loss_seq = self.check_network_convergence(
model,
feed_dict={"image": img,
"label": label},
iter=iter,
batch_size=batch_size,
use_cuda=use_cuda,
use_reduce=True,
optimizer=optimizer,
enable_sequential_execution=True)
for loss in zip(all_reduce_first_loss, all_reduce_first_loss_seq):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(all_reduce_last_loss, all_reduce_last_loss_seq):
self.assertAlmostEquals(loss[0], loss[1], delta=delta2)
for loss in zip(reduce_first_loss, reduce_first_loss_seq):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(reduce_last_loss, reduce_last_loss_seq):
self.assertAlmostEquals(loss[0], loss[1], delta=delta2)
for loss in zip(all_reduce_first_loss_seq, reduce_first_loss_seq):
self.assertAlmostEquals(loss[0], loss[1], delta=1e-6)
for loss in zip(all_reduce_last_loss_seq, reduce_last_loss_seq):
self.assertAlmostEquals(loss[0], loss[1], delta=delta2)
def _check_resnet_convergence(self,
model,
use_cuda=True,
......
......@@ -173,6 +173,8 @@ class TestTransformer(TestParallelExecutorBase):
def test_main(self):
if core.is_compiled_with_cuda():
self.check_network_convergence(transformer, use_cuda=True)
self.check_network_convergence(
transformer, use_cuda=True, enable_sequential_execution=True)
self.check_network_convergence(transformer, use_cuda=False, iter=5)
......
......@@ -26,7 +26,8 @@ def max_pool2D_forward_naive(x,
strides,
paddings,
global_pool=0,
ceil_mode=False):
ceil_mode=False,
exclusive=True):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
......@@ -54,7 +55,8 @@ def avg_pool2D_forward_naive(x,
strides,
paddings,
global_pool=0,
ceil_mode=False):
ceil_mode=False,
exclusive=True):
N, C, H, W = x.shape
if global_pool == 1:
ksize = [H, W]
......@@ -73,8 +75,9 @@ def avg_pool2D_forward_naive(x,
c_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, r_start:r_end, c_start:c_end]
out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / (
(r_end - r_start) * (c_end - c_start))
field_size = ((r_end - r_start) * (c_end - c_start)) if exclusive \
else (ksize[0] * ksize[1])
out[:, :, i, j] = np.sum(x_masked, axis=(2, 3)) / field_size
return out
......@@ -89,12 +92,13 @@ class TestPool2d_Op(OpTest):
self.init_kernel_type()
self.init_pool_type()
self.init_ceil_mode()
self.init_exclusive()
if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype(self.dtype)
output = self.pool2D_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool,
self.ceil_mode).astype(self.dtype)
output = self.pool2D_forward_naive(
input, self.ksize, self.strides, self.paddings, self.global_pool,
self.ceil_mode, self.exclusive).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
self.attrs = {
......@@ -106,7 +110,9 @@ class TestPool2d_Op(OpTest):
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'ceil_mode': self.ceil_mode,
'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter
'data_format':
'AnyLayout', # TODO(dzhwinter) : should be fix latter
'exclusive': self.exclusive
}
self.outputs = {'Out': output}
......@@ -150,6 +156,9 @@ class TestPool2d_Op(OpTest):
def init_ceil_mode(self):
self.ceil_mode = False
def init_exclusive(self):
self.exclusive = True
class TestCase1(TestPool2d_Op):
def init_test_case(self):
......@@ -322,5 +331,15 @@ class TestCeilModeCase4(TestCase2):
self.ceil_mode = True
class TestAvgInclude(TestCase2):
def init_exclusive(self):
self.exclusive = False
class TestCUDNNAvgInclude(TestCUDNNCase3):
def init_exclusive(self):
self.exclusive = False
if __name__ == '__main__':
unittest.main()
......@@ -26,7 +26,8 @@ def max_pool3D_forward_naive(x,
strides,
paddings,
global_pool=0,
ceil_mode=False):
ceil_mode=False,
exclusive=True):
N, C, D, H, W = x.shape
if global_pool == 1:
ksize = [D, H, W]
......@@ -60,7 +61,8 @@ def avg_pool3D_forward_naive(x,
strides,
paddings,
global_pool=0,
ceil_mode=False):
ceil_mode=False,
exclusive=True):
N, C, D, H, W = x.shape
if global_pool == 1:
ksize = [D, H, W]
......@@ -85,8 +87,10 @@ def avg_pool3D_forward_naive(x,
w_end = np.min((j * strides[1] + ksize[1] - paddings[1], W))
x_masked = x[:, :, d_start:d_end, h_start:h_end, w_start:w_end]
out[:, :, k, i, j] = np.sum(x_masked, axis=(2, 3, 4)) / (
(d_end - d_start) * (h_end - h_start) * (w_end - w_start))
field_size = (d_end - d_start) * (h_end - h_start) * (w_end - w_start) \
if exclusive else ksize[0] * ksize[1] * ksize[2]
out[:, :, k, i, j] = np.sum(x_masked, axis=(2, 3,
4)) / field_size
return out
......@@ -100,13 +104,14 @@ class TestPool3d_Op(OpTest):
self.init_kernel_type()
self.init_pool_type()
self.init_ceil_mode()
self.init_exclusive()
if self.global_pool:
self.paddings = [0 for _ in range(len(self.paddings))]
input = np.random.random(self.shape).astype(self.dtype)
output = self.pool3D_forward_naive(input, self.ksize, self.strides,
self.paddings, self.global_pool,
self.ceil_mode).astype(self.dtype)
output = self.pool3D_forward_naive(
input, self.ksize, self.strides, self.paddings, self.global_pool,
self.ceil_mode, self.exclusive).astype(self.dtype)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(input)}
self.attrs = {
......@@ -117,7 +122,9 @@ class TestPool3d_Op(OpTest):
'global_pooling': self.global_pool,
'use_cudnn': self.use_cudnn,
'ceil_mode': self.ceil_mode,
'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter
'data_format':
'AnyLayout', # TODO(dzhwinter) : should be fix latter
'exclusive': self.exclusive
}
self.outputs = {'Out': output}
......@@ -161,6 +168,9 @@ class TestPool3d_Op(OpTest):
def init_ceil_mode(self):
self.ceil_mode = False
def init_exclusive(self):
self.exclusive = True
class TestCase1(TestPool3d_Op):
def init_test_case(self):
......@@ -333,5 +343,15 @@ class TestCeilModeCase4(TestCase2):
self.ceil_mode = True
class TestAvgInclude(TestCase2):
def init_exclusive(self):
self.exclusive = False
class TestCUDNNAvgInclude(TestCUDNNCase3):
def init_exclusive(self):
self.exclusive = False
if __name__ == '__main__':
unittest.main()
......@@ -26,7 +26,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
Test softmax with cross entropy operator with discrete one-hot labels.
"""
def initParams(self):
self.numeric_stable_mode = False
def setUp(self):
self.initParams()
self.op_type = "softmax_with_cross_entropy"
batch_size = 41
class_num = 37
......@@ -46,6 +50,7 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
"Softmax": softmax.astype("float64"),
"Loss": cross_entropy.astype("float64")
}
self.attrs = {"numeric_stable_mode": self.numeric_stable_mode}
def test_check_output(self):
self.check_output()
......@@ -54,6 +59,11 @@ class TestSoftmaxWithCrossEntropyOp(OpTest):
self.check_grad(["Logits"], "Loss")
class TestSoftmaxWithCrossEntropyOpNoCudnn(TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.numeric_stable_mode = True
class TestSoftmaxWithCrossEntropyOp2(OpTest):
"""
Test softmax with cross entropy operator with soft labels.
......@@ -93,7 +103,11 @@ class TestSoftmaxWithCrossEntropyOp3(OpTest):
Test softmax with cross entropy operator with ignore_index.
"""
def initParams(self):
self.numeric_stable_mode = False
def setUp(self):
self.initParams()
self.op_type = "softmax_with_cross_entropy"
batch_size = 41
class_num = 37
......@@ -114,7 +128,10 @@ class TestSoftmaxWithCrossEntropyOp3(OpTest):
"Softmax": softmax.astype("float64"),
"Loss": cross_entropy.astype("float64")
}
self.attrs = {"ignore_index": ignore_index}
self.attrs = {
"ignore_index": ignore_index,
"numeric_stable_mode": self.numeric_stable_mode
}
def test_check_output(self):
self.check_output()
......@@ -123,5 +140,10 @@ class TestSoftmaxWithCrossEntropyOp3(OpTest):
self.check_grad(["Logits"], "Loss")
class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3):
def initParams(self):
self.numeric_stable_mode = True
if __name__ == "__main__":
unittest.main()
......@@ -27,7 +27,7 @@ def _get_version_detail(idx):
if re.match('@TAG_VERSION_REGEX@', '@PADDLE_VERSION@'):
version_details = '@PADDLE_VERSION@'.split('.')
if len(version_details) == 3:
if len(version_details) >= 3:
return version_details[idx]
return 0
......