提交 6ce25c99 编写于 作者: L luotao1

Merge branch 'develop' into runtime_context

......@@ -337,7 +337,6 @@ bool NodeCanReused(const VarDesc& node) {
auto type = node.GetType();
// only these types holds bulk of gpu memory
if (!(type == proto::VarType::LOD_TENSOR ||
type == proto::VarType::SELECTED_ROWS ||
type == proto::VarType::LOD_TENSOR_ARRAY)) {
return false;
}
......
......@@ -46,6 +46,7 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base)
pass_library(cpu_quantize_squash_pass inference)
pass_library(fc_fuse_pass inference)
pass_library(attention_lstm_fuse_pass inference)
pass_library(infer_clean_graph_pass inference)
......@@ -101,6 +102,7 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_cpu_quantize_squash_pass SRCS cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
if (WITH_MKLDNN)
cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file eint8_outcept in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either eint8_outpress or
// implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cpu_quantize_squash_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void CPUQuantizeSquashPass::FindNodesToKeep(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const {
GraphPatternDetector gpd;
patterns::DequantAny deq_any_pattern{gpd.mutable_pattern(), "deqant_any"};
deq_any_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, deq_any_pattern);
if (nodes_keep_counter->find(dequant_out) == nodes_keep_counter->end())
(*nodes_keep_counter)[dequant_out] = 1;
else
(*nodes_keep_counter)[dequant_out] += 1;
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}
void CPUQuantizeSquashPass::Squash(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const {
GraphPatternDetector gpd;
patterns::DequantQuantAny squash_pattern{gpd.mutable_pattern(), "squash"};
squash_pattern();
int found_squash_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "squash requantize-quantize ops pair";
GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);
auto* next_op_desc = next_op->Op();
float dequant_scale = boost::get<float>(dequant_op->Op()->GetAttr("Scale"));
float quant_scale = boost::get<float>(quant_op->Op()->GetAttr("Scale"));
PADDLE_ENFORCE(nodes_keep_counter->find(dequant_out) !=
nodes_keep_counter->end());
// check if dequantize op should be kept or removed, decrease the counter
bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1;
if (dequant_scale == quant_scale) {
// squash dequantize-quantize to nothing
auto quant_out_var_name = quant_out->Name();
auto next_op_inputs = next_op_desc->InputNames();
for (const auto& name : next_op_inputs) {
auto var_name = next_op_desc->Input(name)[0];
if (var_name.compare(quant_out_var_name) == 0) {
next_op_desc->SetInput(
name, std::vector<std::string>({dequant_in->Name()}));
break;
}
}
if (keep_dequant)
GraphSafeRemoveNodes(graph, {quant_op, quant_out});
else
GraphSafeRemoveNodes(graph,
{dequant_op, quant_op, dequant_out, quant_out});
IR_NODE_LINK_TO(dequant_in, next_op);
found_squash_count++;
} else {
// squash dequantize-quantize to requantize op
OpDesc desc;
desc.SetType("requantize");
desc.SetInput("Input", std::vector<std::string>({dequant_in->Name()}));
desc.SetOutput("Output", std::vector<std::string>({quant_out->Name()}));
desc.SetAttr("Scale_in", dequant_scale);
desc.SetAttr("Scale_out", quant_scale);
auto requant_op = g->CreateOpNode(&desc);
if (keep_dequant)
GraphSafeRemoveNodes(graph, {quant_op});
else
GraphSafeRemoveNodes(graph, {dequant_op, quant_op, dequant_out});
IR_NODE_LINK_TO(dequant_in, requant_op);
IR_NODE_LINK_TO(requant_op, quant_out);
found_squash_count++;
}
};
gpd(graph, handler);
AddStatis(found_squash_count);
PrettyLogDetail("--- squashed %d dequantize-quantize pairs",
found_squash_count);
}
std::unique_ptr<ir::Graph> CPUQuantizeSquashPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
FusePassBase::Init("cpu_quantize_squash_pass", graph.get());
std::unordered_map<const Node*, int> nodes_keep_counter;
FindNodesToKeep(graph.get(), &nodes_keep_counter);
Squash(graph.get(), &nodes_keep_counter);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(cpu_quantize_squash_pass,
paddle::framework::ir::CPUQuantizeSquashPass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Squash dequantize->quantize pair pattern into requantize op
*/
class CPUQuantizeSquashPass : public FusePassBase {
public:
virtual ~CPUQuantizeSquashPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
/*
* For each dequantize's output find the number of operators it is an input to
*/
void FindNodesToKeep(
Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const;
/*
* Squash dequantize-quantize ops pairs into requantize or nothing
*/
void Squash(Graph* graph,
std::unordered_map<const Node*, int>* nodes_keep_counter) const;
const std::string name_scope_{"squash"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/cpu_quantize_squash_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, bool use_mkldnn,
float scale = 0) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("name", name);
if (type == "conv2d") {
op->SetInput("Input", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
op->SetOutput("Output", {outputs[0]});
} else if (type == "quantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale", scale);
} else if (type == "dequantize") {
op->SetInput("Input", {inputs[0]});
op->SetOutput("Output", {outputs[0]});
op->SetAttr("Scale", scale);
}
}
// (a,w1,b1)->Conv1->d
// d->Dequant->e
// e->Quant->f
// (f,w2,b2)->Conv2->i
ProgramDesc BuildProgramDesc(bool use_mkldnn, float scale1, float scale2) {
ProgramDesc prog;
for (auto& v : std::initializer_list<std::string>(
{"a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"})) {
auto* var = prog.MutableBlock(0)->Var(v);
if (v.find("w") == 0 || v.find("b") == 0) {
var->SetPersistable(true);
}
}
SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn);
SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, scale1);
SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, scale2);
SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn);
return prog;
}
static const std::initializer_list<std::string> variable_names{
"a", "b", "c", "d", "e", "f", "g", "h"};
// a->Conv1->b
// b->Dequant->c
//
// c->Quant1->d and d->Conv2->e
//
// c->Conv3->f
//
// c->Quant2->g and g->Conv4->h
//
ProgramDesc BuildProgramDesc2(bool use_mkldnn, float scale1, float scale2,
float scale3) {
ProgramDesc prog;
for (auto& v : variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn);
SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1);
SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, scale2);
SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn);
SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn);
SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, scale3);
SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn);
return prog;
}
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
const char* var_name) {
auto x = scope->Var(var_name);
auto tensor = x->GetMutable<LoDTensor>();
tensor->mutable_data(place, proto::VarType::FP32,
::paddle::memory::Allocator::kDefault, 1);
}
void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
// Init scope, as it is used in pass
auto place = paddle::platform::CPUPlace();
NaiveExecutor exe{place};
Scope scope;
exe.CreateVariables(prog, 0, true, &scope);
for (auto& v : variable_names) {
InitTensorHolder(&scope, place, v.c_str());
}
graph->Set(kParamScopeAttr, new framework::Scope*(&scope));
auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass");
int original_nodes_num = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int current_nodes_num = graph->Nodes().size();
EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num);
}
TEST(CpuQuantizeSquashPass, equal_scales) {
auto scale = 1.2345f;
auto use_mkldnn = true;
// Remove 4 nodes: Dequant, Quant, e, f
auto remove_nodes = 4;
MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes);
use_mkldnn = !use_mkldnn;
MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes);
}
TEST(CpuQuantizeSquashPass, inequal_scales) {
auto scale1 = 1.2345f;
auto scale2 = 21.0f;
auto use_mkldnn = true;
// Remove 3 nodes: Dequant, Quant, e
// Insert 1 node: requantize
auto remove_nodes = 2;
MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes);
use_mkldnn = !use_mkldnn;
MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes);
}
TEST(CpuQuantizeSquashPass, branch_to_equal_inequal_and_fp32) {
// Delete both quantize ops,
// bypass dequantize in both branches,
// insert requantize on one branch
auto scale = 1.2345f;
auto scale2 = 21.0f;
auto use_mkldnn = true;
// Remove 3 nodes: Quant1, Quant2, g
// Insert 1 node: requantize
auto remove_nodes = 2;
MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes);
use_mkldnn = !use_mkldnn;
MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(cpu_quantize_squash_pass);
......@@ -1301,6 +1301,51 @@ PDNode *patterns::ConvAffineChannel::operator()(
return ac_out_var;
}
PDNode *patterns::DequantQuantAny::operator()() {
auto *dequant_in = pattern->NewNode(dequant_in_repr())
->AsInput()
->assert_is_op_input("dequantize", "Input");
auto *dequant_op =
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto *dequant_out = pattern->NewNode(dequant_out_repr())
->AsOutput()
->assert_is_op_output("dequantize", "Output");
auto *quant_op = pattern->NewNode(quant_op_repr())
->assert_is_op("quantize")
->AsIntermediate();
auto *quant_out = pattern->NewNode(quant_out_repr())
->AsOutput()
->assert_is_op_output("quantize");
auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out});
quant_op->LinksFrom({dequant_out}).LinksTo({quant_out});
next_op->LinksFrom({quant_out});
return quant_out;
}
PDNode *patterns::DequantAny::operator()() {
auto *dequant_op =
pattern->NewNode(dequant_op_repr())->assert_is_op("dequantize");
auto *dequant_out = pattern->NewNode(dequant_out_repr())
->AsOutput()
->assert_is_op_output("dequantize", "Output");
auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
dequant_op->LinksTo({dequant_out});
next_op->LinksFrom({dequant_out});
return dequant_out;
}
// a -> transpose_op(1) -> transpose_out_a -> flatten_op(1) -> flatten_out_a
// b -> transpose_op(2) -> transpose_out_b -> flatten_op(2) -> flatten_out_b
// ...
......
......@@ -18,8 +18,11 @@
#include <gtest/gtest_prod.h>
#endif
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
......@@ -766,6 +769,34 @@ struct ConvAffineChannel : public PatternBase {
PATTERN_DECL_NODE(ac_out); // Out
};
// Dequantize + Quantize + anyOP
// This pattern is used for squashing the dequantize-quantize pairs.
struct DequantQuantAny : public PatternBase {
DequantQuantAny(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "dequant_quant_any") {}
PDNode* operator()();
PATTERN_DECL_NODE(dequant_in);
PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);
PATTERN_DECL_NODE(quant_op);
PATTERN_DECL_NODE(quant_out);
PATTERN_DECL_NODE(next_op);
};
// Dequantize + anyOP
// This quantize is used for getting number of ops the Dequantize's
// output is an input to.
struct DequantAny : public PatternBase {
DequantAny(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "dequant_any") {}
PDNode* operator()();
PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);
PATTERN_DECL_NODE(next_op);
};
struct TransposeFlattenConcat : public PatternBase {
TransposeFlattenConcat(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "transpose_flatten_concat") {}
......
......@@ -13,18 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cross_entropy_op.h"
#include <memory>
#include <string>
#include <unordered_map>
namespace paddle {
namespace operators {
class CrossEntropyOp : public framework::OperatorWithKernel {
class CrossEntropyOpBase : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
auto x_dims = ctx->GetInputDim("X");
......@@ -43,7 +46,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
}
if (ctx->Attrs().Get<bool>("soft_label")) {
if (IsSoftLabel(ctx)) {
if (check) {
PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
"If Attr(soft_label) == true, the last dimension of "
......@@ -69,21 +73,24 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
}
virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
return ctx->Attrs().Get<bool>("soft_label");
}
};
class CrossEntropyGradientOp : public framework::OperatorWithKernel {
class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
void InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) shoudl be not null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto x_dims = GetXDim(ctx);
auto label_dims = ctx->GetInputDim("Label");
auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
int rank = x_dims.size();
......@@ -108,9 +115,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
"The Input(X) and Input(Y@Grad) should have the same "
"shape except the last dimension.");
}
PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
"The last dimension of Input(Y@Grad) should be 1.");
if (ctx->Attrs().Get<bool>("soft_label")) {
if (IsSoftLabel(ctx)) {
if (check) {
PADDLE_ENFORCE_EQ(
x_dims[rank - 1], label_dims[rank - 1],
......@@ -123,7 +128,10 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
"Input(Label) should be 1.");
}
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD("X", framework::GradVarName("X"));
PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
"The last dimension of Input(Y@Grad) should be 1.");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X"));
}
protected:
......@@ -131,8 +139,28 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Y"))->type(),
ctx.device_context());
}
virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
return ctx->GetInputDim("X");
}
virtual const char* VarNameWithXLoD() const { return "X"; }
virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
return ctx->Attrs().Get<bool>("soft_label");
}
};
class CrossEntropyOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
}
};
......@@ -200,22 +228,132 @@ or not. But the output only shares the LoD information with input X.
}
};
class CrossEntropyOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
public:
using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
CrossEntropyGradientOpBase::InferShape(ctx);
}
};
class CrossEntropyOp2 : public CrossEntropyOpBase {
public:
using CrossEntropyOpBase::CrossEntropyOpBase;
void InferShape(framework::InferShapeContext* ctx) const override {
CrossEntropyOpBase::InferShape(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("MatchX"),
"Output(MatchX) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto x_dims_vec = framework::vectorize(x_dims);
x_dims_vec.push_back(0);
ctx->SetOutputDim("XShape", framework::make_ddim(x_dims_vec));
x_dims[x_dims.size() - 1] = 1;
ctx->SetOutputDim("MatchX", x_dims);
ctx->ShareLoD("X", /*->*/ "XShape");
}
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Y"}};
bool IsSoftLabel(framework::InferShapeContext* ctx) const override {
return false;
}
};
class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
public:
using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("MatchX"), "Input(MatchX) must exist");
CrossEntropyGradientOpBase::InferShape(ctx);
}
protected:
virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
auto x_shape = ctx->GetInputDim("XShape");
return framework::DDim(x_shape.Get(), x_shape.size() - 1);
}
virtual const char* VarNameWithXLoD() const { return "XShape"; }
virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
return false;
}
};
class CrossEntropyOpMaker2 : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), a tensor whose last dimension "
"size is equal to the number of classes. This input is a "
"probability computed by the previous operator, which is almost "
"always the result of a softmax operator.");
AddInput(
"Label",
"(Tensor), the tensor which represents the ground truth. It has the "
"same shape with 'X' except the last dimension. One hot Tensor.");
AddOutput("Y",
"(Tensor, default Tensor<float>), a tensor whose shape is same "
"with 'X' except that the last dimension size is 1. It "
"represents the cross entropy loss.");
AddOutput("XShape", "Temporaily variable to save shape and LoD of X.");
AddOutput("MatchX",
"X value that matches label, used for gradient computation.");
AddAttr<int>("ignore_index",
"(int, default -100), Specifies a target value that is"
"ignored and does not contribute to the input gradient."
"Only valid if soft_label is set to False")
.SetDefault(-100);
AddComment(R"DOC(
Hard-label CrossEntropy Operator.
The input 'X' and 'Label' will first be logically flattened to 2-D matrixs.
The matrix's second dimension(row length) is as same as the original last
dimension, and the first dimension(column length) is the product of all other
original dimensions. Then the softmax computation will take palce on each raw
of flattened matrixs.
Only support hard label.
Both the input X and Label can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD information with input X.
)DOC");
}
};
class CrossEntropyGradOpDescMaker2 : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("cross_entropy_grad2");
op->SetInput("Label", Input("Label"));
op->SetInput("MatchX", Output("MatchX"));
op->SetInput("XShape", Output("XShape"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPUCtx = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOp, ops::CrossEntropyOpMaker,
ops::CrossEntropyOpInferVarType,
REGISTER_OPERATOR(cross_entropy, ops::CrossEntropyOpBase,
ops::CrossEntropyOpMaker, ops::CrossEntropyOpInferVarType,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(cross_entropy_grad, ops::CrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
......@@ -223,3 +361,14 @@ REGISTER_OP_CPU_KERNEL(cross_entropy, ops::CrossEntropyOpKernel<CPUCtx, float>,
REGISTER_OP_CPU_KERNEL(cross_entropy_grad,
ops::CrossEntropyGradientOpKernel<CPUCtx, float>,
ops::CrossEntropyGradientOpKernel<CPUCtx, double>);
REGISTER_OPERATOR(cross_entropy2, ops::CrossEntropyOp2,
ops::CrossEntropyOpMaker2, ops::CrossEntropyOpInferVarType,
ops::CrossEntropyGradOpDescMaker2);
REGISTER_OPERATOR(cross_entropy_grad2, ops::CrossEntropyGradientOp2);
REGISTER_OP_CPU_KERNEL(cross_entropy2,
ops::CrossEntropyOpKernel2<CPUCtx, float>,
ops::CrossEntropyOpKernel2<CPUCtx, double>);
REGISTER_OP_CPU_KERNEL(cross_entropy_grad2,
ops::CrossEntropyGradientOpKernel2<CPUCtx, float>,
ops::CrossEntropyGradientOpKernel2<CPUCtx, double>);
......@@ -27,3 +27,13 @@ REGISTER_OP_CUDA_KERNEL(
cross_entropy_grad, ops::CrossEntropyGradientOpKernel<CUDACtx, float>,
ops::CrossEntropyGradientOpKernel<CUDACtx, double>,
ops::CrossEntropyGradientOpKernel<CUDACtx, plat::float16>);
REGISTER_OP_CUDA_KERNEL(cross_entropy2,
ops::CrossEntropyOpKernel2<CUDACtx, float>,
ops::CrossEntropyOpKernel2<CUDACtx, double>,
ops::CrossEntropyOpKernel2<CUDACtx, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
cross_entropy_grad2, ops::CrossEntropyGradientOpKernel2<CUDACtx, float>,
ops::CrossEntropyGradientOpKernel2<CUDACtx, double>,
ops::CrossEntropyGradientOpKernel2<CUDACtx, plat::float16>);
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"
......@@ -137,5 +138,124 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
}
};
template <typename T>
struct HardLabelCrossEntropyForwardFunctor {
HardLabelCrossEntropyForwardFunctor(const T* x, T* y, T* match_x,
const int64_t* label,
int64_t ignore_index,
int64_t feature_size)
: x_(x),
y_(y),
match_x_(match_x),
label_(label),
ignore_index_(ignore_index),
feature_size_(feature_size) {}
HOSTDEVICE void operator()(int64_t idx) const {
auto label = label_[idx];
if (label != ignore_index_) {
auto match_x = x_[idx * feature_size_ + label];
y_[idx] = -math::TolerableValue<T>()(real_log(match_x));
match_x_[idx] = match_x;
} else {
y_[idx] = 0;
match_x_[idx] = 0; // any value is ok
}
}
const T* x_;
T* y_;
T* match_x_;
const int64_t* label_;
int64_t ignore_index_;
int64_t feature_size_;
};
template <typename T>
struct HardLabelCrossEntropyBackwardFunctor {
HardLabelCrossEntropyBackwardFunctor(T* dx, const T* dy, const T* match_x,
const int64_t* label,
int64_t ignore_index,
int64_t feature_size)
: dx_(dx),
dy_(dy),
match_x_(match_x),
label_(label),
ignore_index_(ignore_index),
feature_size_(feature_size) {}
HOSTDEVICE void operator()(int64_t idx) const {
auto row_idx = idx / feature_size_;
auto col_idx = idx % feature_size_;
auto label = label_[row_idx];
if (label == col_idx && label != ignore_index_) {
dx_[idx] = -dy_[row_idx] / match_x_[row_idx];
} else {
dx_[idx] = 0;
}
}
T* dx_;
const T* dy_;
const T* match_x_;
const int64_t* label_;
int64_t ignore_index_;
int64_t feature_size_;
};
template <typename DeviceContext, typename T>
class CrossEntropyOpKernel2 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* label = ctx.Input<Tensor>("Label");
auto* y = ctx.Output<Tensor>("Y");
auto* match_x = ctx.Output<Tensor>("MatchX");
auto& x_dims = x->dims();
auto feature_size = x_dims[x_dims.size() - 1];
auto batch_size = framework::product(x->dims()) / feature_size;
auto* p_x = x->data<T>();
auto* p_label = label->data<int64_t>();
auto* p_y = y->mutable_data<T>(ctx.GetPlace());
auto* p_match_x = match_x->mutable_data<T>(ctx.GetPlace());
auto ignore_index = ctx.Attr<int>("ignore_index");
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), batch_size);
for_range(HardLabelCrossEntropyForwardFunctor<T>(
p_x, p_y, p_match_x, p_label, ignore_index, feature_size));
}
};
template <typename DeviceContext, typename T>
class CrossEntropyGradientOpKernel2 : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* match_x = ctx.Input<Tensor>("MatchX");
auto* label = ctx.Input<Tensor>("Label");
auto* p_dx = dx->mutable_data<T>(ctx.GetPlace());
auto* p_dy = dy->data<T>();
auto* p_match_x = match_x->data<T>();
auto* p_label = label->data<int64_t>();
int64_t ignore_index = ctx.Attr<int>("ignore_index");
int rank = dx->dims().size();
int64_t feature_size = dx->dims()[rank - 1];
int64_t batch_size = framework::product(dx->dims()) / feature_size;
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(),
batch_size * feature_size);
for_range(HardLabelCrossEntropyBackwardFunctor<T>(
p_dx, p_dy, p_match_x, p_label, ignore_index, feature_size));
}
};
} // namespace operators
} // namespace paddle
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/expand_op.h"
#include <memory>
#include <vector>
namespace paddle {
......@@ -138,12 +139,28 @@ class ExpandGradOp : public framework::OperatorWithKernel {
}
};
class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("expand_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::ExpandGradOpDescMaker);
REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp);
REGISTER_OP_CPU_KERNEL(
expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>,
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "math.h" // NOLINT
namespace paddle {
namespace operators {
inline HOSTDEVICE platform::float16 real_exp(platform::float16 x) {
return static_cast<platform::float16>(::expf(static_cast<float>(x)));
}
inline HOSTDEVICE float real_exp(float x) { return ::expf(x); }
inline HOSTDEVICE double real_exp(double x) { return ::exp(x); }
inline HOSTDEVICE platform::float16 real_log(platform::float16 x) {
return static_cast<platform::float16>(::logf(static_cast<float>(x)));
}
inline HOSTDEVICE float real_log(float x) { return ::logf(x); }
inline HOSTDEVICE double real_log(double x) { return ::log(x); }
} // namespace operators
} // namespace paddle
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/math/cross_entropy.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
......@@ -20,17 +21,6 @@ namespace paddle {
namespace operators {
namespace math {
namespace {
__device__ __forceinline__ float real_log(float x) { return logf(x); }
__device__ __forceinline__ double real_log(double x) { return log(x); }
__device__ __forceinline__ platform::float16 real_log(
const platform::float16& val) {
return static_cast<platform::float16>(logf(static_cast<float>(val)));
}
template <typename T>
__global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label,
const int N, const int D,
......@@ -61,7 +51,6 @@ __global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
Y[blockIdx.x] = -val;
}
}
} // namespace
template <typename T>
class CrossEntropyFunctor<platform::CUDADeviceContext, T> {
......
......@@ -219,14 +219,6 @@ class ReshapeKernel {
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
out_dims = ReshapeOp::ValidateShape(shape, in->dims());
}
if (!in->lod().empty()) {
PADDLE_ENFORCE_EQ(
out_dims[0], in->dims()[0],
"Reshape operator cannot reshape an input sequence batch "
"into an output sequence batch that has a different "
"number of time steps. Please consider using "
"sequence_reshape op.");
}
out->mutable_data(ctx.GetPlace(), in->type());
framework::TensorCopy(
......
......@@ -15,13 +15,12 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
static HOSTDEVICE float real_exp(float x) { return expf(x); }
static HOSTDEVICE float real_exp(double x) { return exp(x); }
template <typename T>
struct SeluFunctor {
SeluFunctor(const T* x_data_ptr, float alpha, float scale, T* y_data_ptr)
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <algorithm>
#include <cub/cub.cuh> // NOLINT
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/sequence_ops/sequence_softmax_op.h"
namespace paddle {
......@@ -21,9 +22,6 @@ namespace operators {
using LoDTensor = framework::LoDTensor;
__device__ __forceinline__ float real_exp(float x) { return expf(x); }
__device__ __forceinline__ double real_exp(double x) { return exp(x); }
template <typename T, int BlockDim>
using BlockReduce = cub::BlockReduce<T, BlockDim>;
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/hostdevice.h"
......@@ -21,11 +22,6 @@ namespace operators {
using Tensor = framework::Tensor;
static HOSTDEVICE float real_exp(float x) { return expf(x); }
static HOSTDEVICE float real_exp(double x) { return exp(x); }
static HOSTDEVICE float real_log(float x) { return logf(x); }
static HOSTDEVICE float real_log(double x) { return log(x); }
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
......
......@@ -12,18 +12,138 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/slice_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
template <size_t D>
__global__ void Padding(const paddle::platform::float16* d_out,
const int* out_dims, const int* in_dims,
const int* offsets, int64_t n,
paddle::platform::float16* d_in) {
int64_t out_idx = threadIdx.x + blockDim.x * blockIdx.x;
if (out_idx < n) {
int coords[D] = {0};
for (int i = D - 1; i >= 0; --i) {
coords[i] = out_idx % out_dims[i];
out_idx /= out_dims[i];
coords[i] += offsets[i];
}
int64_t in_idx = 0;
for (int i = 0; i < D - 1; ++i) {
in_idx += coords[i] * in_dims[i + 1];
}
in_idx += coords[D - 1];
d_in[in_idx] = d_out[out_idx];
}
}
template <>
class SliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>
: public framework::OpKernel<paddle::platform::float16> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_in = ctx.Output<framework::Tensor>(framework::GradVarName("Input"));
d_in->mutable_data<paddle::platform::float16>(ctx.GetPlace());
auto out_dims = d_out->dims();
auto in_dims = d_in->dims();
int rank = out_dims.size();
std::vector<int> offsets(rank, 0);
auto axes = ctx.Attr<std::vector<int>>("axes");
auto starts = ctx.Attr<std::vector<int>>("starts");
for (size_t i = 0; i < starts.size(); ++i) {
if (starts[i] < 0) {
starts[i] += in_dims[axes[i]];
}
offsets[axes[i]] = std::max(starts[i], 0);
}
math::SetConstant<paddle::platform::CUDADeviceContext,
paddle::platform::float16>
set_zero;
auto& dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
set_zero(dev_ctx, d_in, static_cast<paddle::platform::float16>(0));
int64_t numel = d_out->numel();
dim3 blocks((numel - 1) / PADDLE_CUDA_NUM_THREADS + 1, 1, 1);
dim3 threads(PADDLE_CUDA_NUM_THREADS, 1, 1);
auto stream = ctx.cuda_device_context().stream();
auto out_shape = framework::vectorize2int(out_dims);
thrust::device_vector<int> out_dims_vec(out_shape.begin(), out_shape.end());
auto in_shape = framework::vectorize2int(in_dims);
thrust::device_vector<int> in_dims_vec(in_shape.begin(), in_shape.end());
thrust::device_vector<int> offsets_vec(offsets.begin(), offsets.end());
const int* out_dims_ptr = thrust::raw_pointer_cast(out_dims_vec.data());
const int* in_dims_ptr = thrust::raw_pointer_cast(in_dims_vec.data());
const int* offsets_ptr = thrust::raw_pointer_cast(offsets_vec.data());
switch (rank) {
case 1:
Padding<1><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 2:
Padding<2><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 3:
Padding<3><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 4:
Padding<4><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 5:
Padding<5><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
case 6:
Padding<6><<<blocks, threads, 0, stream>>>(
d_out->data<paddle::platform::float16>(), out_dims_ptr, in_dims_ptr,
offsets_ptr, numel, d_in->data<paddle::platform::float16>());
break;
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
slice, ops::SliceKernel<paddle::platform::CUDADeviceContext, float>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
slice_grad,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, plat::float16>);
......@@ -1432,6 +1432,8 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
"""
if not soft_label:
return cross_entropy2(input, label, ignore_index)
helper = LayerHelper('cross_entropy', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
......@@ -1444,6 +1446,22 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
return out
def cross_entropy2(input, label, ignore_index=kIgnoreIndex):
helper = LayerHelper('cross_entropy2', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
xshape = helper.create_variable_for_type_inference(dtype=input.dtype)
match_x = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='cross_entropy2',
inputs={'X': [input],
'Label': [label]},
outputs={'Y': [out],
'MatchX': [match_x],
'XShape': [xshape]},
attrs={'ignore_index': ignore_index})
return out
def bpr_loss(input, label, name=None):
"""
Bayesian Personalized Ranking Loss Operator.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from op_test import OpTest
import unittest
import numpy as np
import six
class CrossEntropy2OpTestBase(OpTest):
def initParameters(self):
return [32, 64], 'float32', -100
def calc_output(self, logits, label, ignore_index):
ret = np.zeros(shape=label.shape, dtype=logits.dtype)
match_x = np.zeros(shape=label.shape, dtype=logits.dtype)
for idx in six.moves.range(label.shape[0]):
if label[idx] == ignore_index:
continue
match_x[idx] = logits[idx][label[idx]]
ret[idx] = -np.log(match_x[idx])
return ret, match_x
def setUp(self):
self.shape, self.dtype, self.ignore_index = self.initParameters()
self.op_type = 'cross_entropy2'
feature_size = int(self.shape[-1])
batch_size = int(np.prod(self.shape) / feature_size)
logits = (np.random.random(size=self.shape) + 1).astype(self.dtype)
label = np.random.random_integers(
low=0, high=feature_size - 1,
size=self.shape[0:-1] + [1]).astype('int64')
outputs, match_x = self.calc_output(
np.reshape(logits, [batch_size, feature_size]),
np.reshape(label, [batch_size, 1]), self.ignore_index)
self.inputs = {'X': logits, 'Label': label}
self.outputs = {
'Y': np.reshape(outputs, label.shape),
'MatchX': np.reshape(match_x, label.shape),
'XShape': np.zeros(
shape=logits.shape, dtype=logits.dtype)
}
self.attrs = {'ignore_index': self.ignore_index}
def test_check_output(self):
self.check_output(no_check_set=['XShape'])
def test_check_grad(self):
self.check_grad(
inputs_to_check=['X'],
output_names=['Y'],
no_grad_set=['XShape', 'MatchX', 'Label'])
class CrossEntropy2OpTest2(CrossEntropy2OpTestBase):
def initParameters(self):
return [32, 64], 'float64', 3
class CrossEntropy2OpTest3(CrossEntropy2OpTestBase):
def initParameters(self):
return [4, 8, 16, 32], 'float32', -100
class CrossEntropy2OpTest4(CrossEntropy2OpTestBase):
def initParameters(self):
return [4, 8, 16, 32], 'float32', 3
if __name__ == '__main__':
unittest.main()
......@@ -524,8 +524,8 @@ class TestLocalLookupTable(TestDistLookupTableBase):
ops = [
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
'cross_entropy2', 'mean', 'fill_constant', 'mean_grad',
'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad',
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'split_selected_rows', 'send', 'sequence_pool_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
......@@ -564,8 +564,8 @@ class TestDistLookupTable(TestDistLookupTableBase):
ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul',
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
'elementwise_add', 'cross_entropy2', 'mean', 'fill_constant',
'mean_grad', 'cross_entropy_grad2', 'elementwise_add_grad', 'send',
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
'lookup_table_grad', 'split_selected_rows', 'send',
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
......@@ -612,8 +612,8 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
ops = [
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
'cross_entropy2', 'mean', 'fill_constant', 'mean_grad',
'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad',
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'split_selected_rows', 'send', 'sequence_pool_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
......@@ -652,8 +652,8 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
'sequence_pool', 'lookup_table', 'sequence_pool', 'concat', 'mul',
'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
'elementwise_add', 'cross_entropy2', 'mean', 'fill_constant',
'mean_grad', 'cross_entropy_grad2', 'elementwise_add_grad', 'send',
'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
'lookup_table_grad', 'split_selected_rows', 'send',
'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
......@@ -841,8 +841,8 @@ class TestRemoteLookupTable(TestDistLookupTableBase):
ops = [
'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
'lookup_table', 'sequence_pool', 'concat', 'mul', 'elementwise_add',
'cross_entropy', 'mean', 'fill_constant', 'mean_grad',
'cross_entropy_grad', 'elementwise_add_grad', 'send', 'mul_grad',
'cross_entropy2', 'mean', 'fill_constant', 'mean_grad',
'cross_entropy_grad2', 'elementwise_add_grad', 'send', 'mul_grad',
'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
'split_selected_rows', 'send', 'sequence_pool_grad',
'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
......@@ -63,5 +64,28 @@ class TestCase2(TestSliceOp):
self.out = self.input[-3:3, 0:100, :, 2:-1]
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFP16(TestSliceOp):
def config(self):
self.dtype = "float16"
self.input = np.random.random([3, 4, 5, 6]).astype(self.dtype)
self.starts = [-3, 0, 2]
self.ends = [3, 100, -1]
self.axes = [0, 1, 3]
self.out = self.input[-3:3, 0:100, :, 2:-1]
def test_check_output(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=1e-5)
def test_check_grad_normal(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place, ['Input'], 'Out', max_relative_error=0.006)
if __name__ == '__main__':
unittest.main()
......@@ -160,6 +160,8 @@ class Timeline(object):
self._devices[(k, event.device_id, "GPUKernel")] = pid
self._chrome_trace.emit_pid("%s:gpu:%d" %
(k, event.device_id), pid)
if not hasattr(profile_pb, "mem_events"):
continue
for mevent in profile_pb.mem_events:
if mevent.place == profiler_pb2.MemEvent.CUDAPlace:
if (k, mevent.device_id, "GPU") not in self._mem_devices:
......@@ -211,7 +213,7 @@ class Timeline(object):
args = {'name': event.name}
if event.memcopy.bytes > 0:
args['mem_bytes'] = event.memcopy.bytes
if event.detail_info:
if hasattr(event, "detail_info") and event.detail_info:
args['detail_info'] = event.detail_info
# TODO(panyx0718): Chrome tracing only handles ms. However, some
# ops takes micro-seconds. Hence, we keep the ns here.
......@@ -220,6 +222,8 @@ class Timeline(object):
event.sub_device_id, 'Op', event.name, args)
def _allocate_memory_event(self):
if not hasattr(profiler_pb2, "MemEvent"):
return
place_to_str = {
profiler_pb2.MemEvent.CPUPlace: "CPU",
profiler_pb2.MemEvent.CUDAPlace: "GPU",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册