提交 b2943530 编写于 作者: D Dang Qingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into map_api

...@@ -80,7 +80,6 @@ message OpProto { ...@@ -80,7 +80,6 @@ message OpProto {
optional bool duplicable = 3 [ default = false ]; optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ]; optional bool intermediate = 4 [ default = false ];
optional bool dispensable = 5 [ default = false ]; optional bool dispensable = 5 [ default = false ];
optional string reuse = 6;
} }
// AttrProto describes the C++ type Attribute. // AttrProto describes the C++ type Attribute.
......
...@@ -42,12 +42,10 @@ if(WITH_MKLDNN) ...@@ -42,12 +42,10 @@ if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base) pass_library(mkldnn_placement_pass base)
pass_library(conv_bias_mkldnn_fuse_pass inference) pass_library(conv_bias_mkldnn_fuse_pass inference)
pass_library(conv_relu_mkldnn_fuse_pass inference) pass_library(conv_relu_mkldnn_fuse_pass inference)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
endif() endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
if(WITH_MKLDNN)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
endif()
set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library")
......
...@@ -200,15 +200,15 @@ TEST(GraphHelperTest, GraphNum) { ...@@ -200,15 +200,15 @@ TEST(GraphHelperTest, GraphNum) {
Graph g(prog); Graph g(prog);
BuildZeroGraph(&g); BuildZeroGraph(&g);
ASSERT_EQ(GraphNum(g), 0); ASSERT_EQ(GraphNum(g), 0UL);
Graph g2(prog); Graph g2(prog);
BuildOneGraph(&g2); BuildOneGraph(&g2);
ASSERT_EQ(GraphNum(g2), 1); ASSERT_EQ(GraphNum(g2), 1UL);
Graph g3(prog); Graph g3(prog);
BuildTwoGraphs(&g3); BuildTwoGraphs(&g3);
ASSERT_EQ(GraphNum(g3), 2); ASSERT_EQ(GraphNum(g3), 2UL);
} }
} // namespace ir } // namespace ir
......
...@@ -124,7 +124,7 @@ TEST(GraphTest, Basic) { ...@@ -124,7 +124,7 @@ TEST(GraphTest, Basic) {
ASSERT_EQ(n->outputs.size(), 0UL); ASSERT_EQ(n->outputs.size(), 0UL);
} }
} }
ASSERT_EQ(nodes.size(), 5); ASSERT_EQ(nodes.size(), 5UL);
} }
TEST(GraphTest, WriteAfterRead) { TEST(GraphTest, WriteAfterRead) {
......
...@@ -515,20 +515,14 @@ void OpDesc::InferShape(const BlockDesc &block) const { ...@@ -515,20 +515,14 @@ void OpDesc::InferShape(const BlockDesc &block) const {
} }
void OpDesc::InferVarType(BlockDesc *block) const { void OpDesc::InferVarType(BlockDesc *block) const {
// There are a few places that var type can be set.
// When VarDesc is created, default set to LOD_TENSOR.
// When output variable is created, default is defaut set to LOD_TENSOR.
// We limit here to be the only place that operator defines its customized
// var type inference. Hence, we don't do any "default" setting here.
auto &info = OpInfoMap::Instance().Get(this->Type()); auto &info = OpInfoMap::Instance().Get(this->Type());
if (info.infer_var_type_) { if (info.infer_var_type_) {
info.infer_var_type_(*this, block); info.infer_var_type_(*this, block);
} else {
// all output type is LoDTensor by default
VLOG(10) << this->Type()
<< " has not registered InferVarType. Set output variables to "
"LOD_TENSOR";
for (auto &out_pair : this->outputs_) {
for (auto &out_var_name : out_pair.second) {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(proto::VarType::LOD_TENSOR);
}
}
} }
} }
......
...@@ -21,7 +21,6 @@ namespace framework { ...@@ -21,7 +21,6 @@ namespace framework {
void OpProtoAndCheckerMaker::Validate() { void OpProtoAndCheckerMaker::Validate() {
validated_ = true; validated_ = true;
CheckNoDuplicatedInOutAttrs(); CheckNoDuplicatedInOutAttrs();
CheckReuseVars();
} }
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput( OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
...@@ -40,40 +39,6 @@ OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput( ...@@ -40,40 +39,6 @@ OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
return OpProtoAndCheckerMaker::VariableBuilder{output}; return OpProtoAndCheckerMaker::VariableBuilder{output};
} }
void OpProtoAndCheckerMaker::Reuse(const std::string& name,
const std::string& reused_name) {
bool found = false;
proto::OpProto::Var* var;
for (auto& var : proto_->inputs()) {
if (var.name() == reused_name) {
found = true;
break;
}
}
PADDLE_ENFORCE(found == true,
"Input/Output name: %s reused_name: %s, one of them is not "
"exists or not matched.",
name, reused_name);
found = false;
for (int i = 0; i < proto_->outputs().size(); ++i) {
var = proto_->mutable_outputs()->Mutable(i);
if (var->name() == name) {
PADDLE_ENFORCE(!var->has_reuse(),
"Output(%s) has been set reused var of %s", name,
var->reuse());
found = true;
var->set_reuse(reused_name);
break;
}
}
PADDLE_ENFORCE(found == true,
"Input/Output name: %s reused_name: %s, one of them is not "
"exists or not matched.",
name, reused_name);
}
void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names; std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) { auto checker = [&](const std::string& name) {
...@@ -91,24 +56,6 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() { ...@@ -91,24 +56,6 @@ void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
} }
} }
void OpProtoAndCheckerMaker::CheckReuseVars() {
std::unordered_set<std::string> names;
for (auto& input : proto_->inputs()) {
names.insert(input.name());
}
auto checker = [&](const std::string& name, const std::string& reused) {
PADDLE_ENFORCE(
names.count(reused),
"Output [%s] reuse Input [%s], but the input is not registered.", name,
reused);
};
for (auto& output : proto_->outputs()) {
if (output.has_reuse()) {
checker(output.name(), output.reuse());
}
}
}
void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
OpAttrChecker* attr_checker) { OpAttrChecker* attr_checker) {
proto_ = proto; proto_ = proto;
......
...@@ -14,8 +14,6 @@ limitations under the License. */ ...@@ -14,8 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <unordered_set>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
...@@ -73,11 +71,6 @@ class OpProtoAndCheckerMaker { ...@@ -73,11 +71,6 @@ class OpProtoAndCheckerMaker {
var_->set_dispensable(true); var_->set_dispensable(true);
return *this; return *this;
} }
VariableBuilder &Reuse(const std::string &name) {
var_->set_reuse(name);
return *this;
}
}; };
VariableBuilder AddInput(const std::string &name, const std::string &comment); VariableBuilder AddInput(const std::string &name, const std::string &comment);
...@@ -85,8 +78,6 @@ class OpProtoAndCheckerMaker { ...@@ -85,8 +78,6 @@ class OpProtoAndCheckerMaker {
VariableBuilder AddOutput(const std::string &name, VariableBuilder AddOutput(const std::string &name,
const std::string &comment); const std::string &comment);
void Reuse(const std::string &name, const std::string &reused_name);
template <typename T> template <typename T>
TypedAttrChecker<T> &AddAttr(const std::string &name, TypedAttrChecker<T> &AddAttr(const std::string &name,
const std::string &comment, const std::string &comment,
...@@ -105,8 +96,6 @@ class OpProtoAndCheckerMaker { ...@@ -105,8 +96,6 @@ class OpProtoAndCheckerMaker {
void CheckNoDuplicatedInOutAttrs(); void CheckNoDuplicatedInOutAttrs();
void Validate(); void Validate();
void CheckReuseVars();
proto::OpProto *proto_; proto::OpProto *proto_;
OpAttrChecker *op_checker_; OpAttrChecker *op_checker_;
bool validated_{false}; bool validated_{false};
......
...@@ -47,120 +47,3 @@ TEST(ProtoMaker, DuplicatedInOut) { ...@@ -47,120 +47,3 @@ TEST(ProtoMaker, DuplicatedInOut) {
ASSERT_THROW(proto_maker(&op_proto, &op_checker), ASSERT_THROW(proto_maker(&op_proto, &op_checker),
paddle::platform::EnforceNotMet); paddle::platform::EnforceNotMet);
} }
class TestInplaceProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "input of test op");
AddOutput("XOut", "output of test op").Reuse("X");
}
};
class TestInplaceProtoMaker2
: public paddle::framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "input of test op");
AddOutput("XOut", "output of test op").Reuse("X");
AddOutput("NoOut", "output of test op").Reuse("NotExists");
}
};
TEST(ProtoMaker, InplaceOutput) {
paddle::framework::proto::OpProto op_proto, op_proto2;
paddle::framework::OpAttrChecker op_checker;
TestInplaceProtoMaker proto_maker;
TestInplaceProtoMaker2 proto_maker2;
proto_maker(&op_proto, &op_checker);
ASSERT_THROW(proto_maker2(&op_proto2, &op_checker),
paddle::platform::EnforceNotMet);
}
// normal reuse
class TestReuseProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "input of test op");
AddInput("Y", "input of test op");
AddOutput("Out", "output of test op");
AddOutput("XOut", "output of test op");
// avoid destructor exception.
// Validate();
TestReuse();
}
virtual void TestReuse() {}
};
// test duplicate reuse error
class TestReuseProtoMaker2 : public TestReuseProtoMaker {
public:
void TestReuse() {
Reuse("Out", "X");
Reuse("Out", "Y");
}
};
// NotExists Input
class TestReuseProtoMaker3 : public TestReuseProtoMaker {
public:
void TestReuse() {
Reuse("Out", "NotExists");
Reuse("XOut", "X");
}
};
// NotExists Output
class TestReuseProtoMaker4 : public TestReuseProtoMaker {
public:
void TestReuse() { Reuse("NotExists", "X"); }
};
TEST(ProtoMaker, Reuse) {
paddle::framework::proto::OpProto op_proto;
paddle::framework::OpAttrChecker op_checker;
TestReuseProtoMaker proto_maker;
proto_maker(&op_proto, &op_checker);
}
// NOTE(dzhwinter):
// There is a Fatal CHECK on base class destructor, which will call abort inside
// instead of
// throw an exception. If we throw an exception in Make(), we will trigger the
// CHECK and terminate the tests.
//
// I had tried to replace the default CHECK with a exception, however, it's
// still not supported by glog.
// the details:
// https://github.com/google/glog/issues/249
// https://github.com/facebookresearch/TensorComprehensions/issues/351
/*
TEST(ProtoMaker, ReuseWithException) {
paddle::framework::proto::OpProto op_proto2, op_proto3, op_proto4;
paddle::framework::OpAttrChecker op_checker;
TestReuseProtoMaker2 proto_maker2;
TestReuseProtoMaker3 proto_maker3;
TestReuseProtoMaker4 proto_maker4;
EXPECT_THROW(proto_maker2(&op_proto2, &op_checker),
paddle::platform::EnforceNotMet);
EXPECT_THROW(proto_maker3(&op_proto3, &op_checker),
paddle::platform::EnforceNotMet);
EXPECT_THROW(proto_maker4(&op_proto4, &op_checker),
paddle::platform::EnforceNotMet);
}
void FailureFunction() {
throw std::runtime_error("Check failed in destructor.");
// return 0;
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
google::InstallFailureFunction(&FailureFunction);
return RUN_ALL_TESTS();
}
*/
...@@ -156,12 +156,10 @@ ParallelExecutor::ParallelExecutor( ...@@ -156,12 +156,10 @@ ParallelExecutor::ParallelExecutor(
params, member_->local_scopes_, member_->use_cuda_); params, member_->local_scopes_, member_->use_cuda_);
#endif #endif
if (VLOG_IS_ON(5)) { // If the loss_var_name is given, the number of graph should be only one.
// If the loss_var_name is given, the number of graph should be only one. if (loss_var_name.size()) {
if (loss_var_name.size()) { PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1, "The number of graph should be only one");
"The number of graph should be only one");
}
} }
if (exec_strategy.type_ == ExecutionStrategy::kDefault) { if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
......
...@@ -103,7 +103,7 @@ TEST(ProgramDesc, copy_ctor) { ...@@ -103,7 +103,7 @@ TEST(ProgramDesc, copy_ctor) {
ASSERT_EQ(1, op->GetBlockAttrId("sub_block")); ASSERT_EQ(1, op->GetBlockAttrId("sub_block"));
found_sub_block = true; found_sub_block = true;
ASSERT_EQ(2, op->GetBlocksAttrIds("sub_blocks").size()); ASSERT_EQ(2UL, op->GetBlocksAttrIds("sub_blocks").size());
found_sub_blocks = true; found_sub_blocks = true;
} }
} }
......
...@@ -40,7 +40,7 @@ TEST(READER, decorate_chain) { ...@@ -40,7 +40,7 @@ TEST(READER, decorate_chain) {
auto endpoints = root->GetEndPoints(); auto endpoints = root->GetEndPoints();
ASSERT_EQ(endpoints.size(), 2U); ASSERT_EQ(endpoints.size(), 2U);
ASSERT_NE(endpoints.count(end_point1.get()), 0UL); ASSERT_NE(endpoints.count(end_point1.get()), 0UL);
ASSERT_NE(endpoints.count(end_point2.get()), 0); ASSERT_NE(endpoints.count(end_point2.get()), 0UL);
} }
{ {
......
...@@ -21,7 +21,7 @@ else ...@@ -21,7 +21,7 @@ else
fi fi
USE_TENSORRT=OFF USE_TENSORRT=OFF
if [ [-d"$TENSORRT_INCLUDE_DIR"] -a [-d"$TENSORRT_LIB_DIR"] ]; then if [ -d "$TENSORRT_INCLUDE_DIR" -a -d "$TENSORRT_LIB_DIR" ]; then
USE_TENSORRT=ON USE_TENSORRT=ON
fi fi
......
...@@ -42,16 +42,22 @@ class Pool2dOpConverter : public OpConverter { ...@@ -42,16 +42,22 @@ class Pool2dOpConverter : public OpConverter {
boost::get<std::vector<int>>(op_desc.GetAttr("strides")); boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
std::vector<int> paddings = std::vector<int> paddings =
boost::get<std::vector<int>>(op_desc.GetAttr("paddings")); boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
bool ceil_mode = boost::get<bool>(op_desc.GetAttr("ceil_mode"));
nvinfer1::Dims input_shape = input1->getDimensions();
int nbDims = input_shape.nbDims;
nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]); nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]);
nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
if (global_pooling == true) { if (global_pooling == true) {
nvinfer1::Dims input_shape = input1->getDimensions();
int nbDims = input_shape.nbDims;
nv_ksize.d[0] = input_shape.d[nbDims - 2]; nv_ksize.d[0] = input_shape.d[nbDims - 2];
nv_ksize.d[1] = input_shape.d[nbDims - 1]; nv_ksize.d[1] = input_shape.d[nbDims - 1];
nv_strides.h() = 1;
nv_strides.w() = 1;
nv_paddings.h() = 0;
nv_paddings.w() = 0;
} }
const nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
PADDLE_ENFORCE_EQ(input1->getDimensions().nbDims, 3UL); PADDLE_ENFORCE_EQ(input1->getDimensions().nbDims, 3UL);
...@@ -64,6 +70,36 @@ class Pool2dOpConverter : public OpConverter { ...@@ -64,6 +70,36 @@ class Pool2dOpConverter : public OpConverter {
PADDLE_THROW("TensorRT unsupported pooling type!"); PADDLE_THROW("TensorRT unsupported pooling type!");
} }
if (ceil_mode) {
nvinfer1::DimsHW pre_pad(0, 0);
nvinfer1::DimsHW post_pad(0, 0);
int input_height = input_shape.d[nbDims - 2];
int input_width = input_shape.d[nbDims - 1];
int floor_h_output_size =
(input_height - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
int ceil_h_output_size =
(input_height - ksize[0] + 2 * paddings[0] + strides[0] - 1) /
strides[0] +
1;
int floor_w_output_size =
(input_width - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
int ceil_w_output_size =
(input_width - ksize[1] + 2 * paddings[1] + strides[1] - 1) /
strides[1] +
1;
if (floor_h_output_size != ceil_h_output_size) {
post_pad.h() = strides[0] - 1;
}
if (floor_w_output_size != ceil_w_output_size) {
post_pad.w() = strides[1] - 1;
}
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, Padding, *const_cast<nvinfer1::ITensor*>(input1), pre_pad,
post_pad);
input1 = layer->getOutput(0);
}
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling,
*const_cast<nvinfer1::ITensor*>(input1), *const_cast<nvinfer1::ITensor*>(input1),
nv_pool_type, nv_ksize); nv_pool_type, nv_ksize);
......
...@@ -20,18 +20,20 @@ namespace paddle { ...@@ -20,18 +20,20 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
void test_pool2d(bool global_pooling) { void test_pool2d(bool global_pooling, bool ceil_mode) {
framework::Scope scope; framework::Scope scope;
std::unordered_set<std::string> parameters; std::unordered_set<std::string> parameters;
TRTConvertValidation validator(5, parameters, scope, 1 << 15); TRTConvertValidation validator(5, parameters, scope, 1 << 15);
// The ITensor's Dims should not contain the batch size. // The ITensor's Dims should not contain the batch size.
// So, the ITensor's Dims of input and output should be C * H * W. // So, the ITensor's Dims of input and output should be C * H * W.
validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4)); validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 13, 14));
if (global_pooling) if (global_pooling)
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1)); validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1));
else if (ceil_mode)
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 6, 7));
else else
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2)); validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 6, 6));
// Prepare Op description // Prepare Op description
framework::OpDesc desc; framework::OpDesc desc;
...@@ -39,7 +41,7 @@ void test_pool2d(bool global_pooling) { ...@@ -39,7 +41,7 @@ void test_pool2d(bool global_pooling) {
desc.SetInput("X", {"pool2d-X"}); desc.SetInput("X", {"pool2d-X"});
desc.SetOutput("Out", {"pool2d-Out"}); desc.SetOutput("Out", {"pool2d-Out"});
std::vector<int> ksize({2, 2}); std::vector<int> ksize({3, 3});
std::vector<int> strides({2, 2}); std::vector<int> strides({2, 2});
std::vector<int> paddings({0, 0}); std::vector<int> paddings({0, 0});
std::string pooling_t = "max"; std::string pooling_t = "max";
...@@ -49,6 +51,7 @@ void test_pool2d(bool global_pooling) { ...@@ -49,6 +51,7 @@ void test_pool2d(bool global_pooling) {
desc.SetAttr("strides", strides); desc.SetAttr("strides", strides);
desc.SetAttr("paddings", paddings); desc.SetAttr("paddings", paddings);
desc.SetAttr("global_pooling", global_pooling); desc.SetAttr("global_pooling", global_pooling);
desc.SetAttr("ceil_mode", ceil_mode);
LOG(INFO) << "set OP"; LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto()); validator.SetOp(*desc.Proto());
...@@ -57,9 +60,10 @@ void test_pool2d(bool global_pooling) { ...@@ -57,9 +60,10 @@ void test_pool2d(bool global_pooling) {
validator.Execute(3); validator.Execute(3);
} }
TEST(Pool2dOpConverter, normal) { test_pool2d(false); } TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); }
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true, false); }
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true); } TEST(Pool2dOpConverter, test_ceil_mode) { test_pool2d(false, true); }
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
...@@ -71,7 +71,7 @@ void profile(bool use_mkldnn = false) { ...@@ -71,7 +71,7 @@ void profile(bool use_mkldnn = false) {
} }
TEST(Analyzer_resnet50, profile) { profile(); } TEST(Analyzer_resnet50, profile) { profile(); }
#ifndef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_resnet50, profile_mkldnn) { profile(true /* use_mkldnn */); } TEST(Analyzer_resnet50, profile_mkldnn) { profile(true /* use_mkldnn */); }
#endif #endif
......
...@@ -50,7 +50,7 @@ void CompareResult(const std::vector<PaddleTensor> &outputs, ...@@ -50,7 +50,7 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
auto &ref_out = ref_outputs[i]; auto &ref_out = ref_outputs[i];
size_t size = VecReduceToInt(out.shape); size_t size = VecReduceToInt(out.shape);
size_t ref_size = VecReduceToInt(ref_out.shape); size_t ref_size = VecReduceToInt(ref_out.shape);
EXPECT_GT(size, 0); EXPECT_GT(size, 0UL);
EXPECT_EQ(size, ref_size); EXPECT_EQ(size, ref_size);
EXPECT_EQ(out.dtype, ref_out.dtype); EXPECT_EQ(out.dtype, ref_out.dtype);
switch (out.dtype) { switch (out.dtype) {
......
...@@ -284,10 +284,10 @@ op_library(max_sequence_len_op DEPS lod_rank_table) ...@@ -284,10 +284,10 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(sequence_conv_op DEPS context_project) op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling) op_library(sequence_pool_op DEPS sequence_pooling)
if (NOT WIN32) if (NOT WIN32)
op_library(lstm_op DEPS sequence2batch lstm_compute) op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(hierarchical_sigmoid_op DEPS matrix_bit_code) op_library(hierarchical_sigmoid_op DEPS matrix_bit_code)
op_library(lstmp_op DEPS sequence2batch lstm_compute) op_library(lstmp_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute) op_library(gru_op DEPS sequence2batch gru_compute)
endif(NOT WIN32) endif(NOT WIN32)
op_library(recurrent_op DEPS executor) op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
...@@ -316,7 +316,7 @@ op_library(save_op DEPS lod_tensor) ...@@ -316,7 +316,7 @@ op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor) op_library(load_op DEPS lod_tensor)
op_library(save_combine_op DEPS lod_tensor) op_library(save_combine_op DEPS lod_tensor)
op_library(load_combine_op DEPS lod_tensor) op_library(load_combine_op DEPS lod_tensor)
op_library(concat_op DEPS concat) op_library(concat_op DEPS concat_and_split)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
...@@ -348,6 +348,6 @@ cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory) ...@@ -348,6 +348,6 @@ cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
if(NOT WIN32) if(NOT WIN32)
nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context) nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
endif() endif()
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
...@@ -28,7 +28,7 @@ using paddle::framework::Tensor; ...@@ -28,7 +28,7 @@ using paddle::framework::Tensor;
public: \ public: \
void Make() override { \ void Make() override { \
AddInput("X", "Input of " #OP_NAME " operator"); \ AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X"); \ AddOutput("Out", "Output of " #OP_NAME " operator"); \
AddAttr<bool>("use_mkldnn", \ AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \ "(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \ .SetDefault(false); \
......
...@@ -92,9 +92,9 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -92,9 +92,9 @@ class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator"); AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator"); AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
AddOutput("ParamOut", "(Tensor) Output parameter").Reuse("Param"); AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("Moment1Out", "(Tensor) Output first moment").Reuse("Moment1"); AddOutput("Moment1Out", "(Tensor) Output first moment");
AddOutput("Moment2Out", "(Tensor) Output second moment").Reuse("Moment2"); AddOutput("Moment2Out", "(Tensor) Output second moment");
AddAttr<float>("beta1", AddAttr<float>("beta1",
"(float, default 0.9) " "(float, default 0.9) "
......
...@@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <paddle/fluid/operators/math/concat.h> #include <paddle/fluid/operators/math/concat_and_split.h>
#include <numeric> #include <numeric>
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
......
...@@ -135,15 +135,13 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -135,15 +135,13 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Variance", AddInput("Variance",
"The global variance (for training) " "The global variance (for training) "
"or estimated Variance (for testing)"); "or estimated Variance (for testing)");
AddOutput("Y", "result after normalization").Reuse("X"); AddOutput("Y", "result after normalization");
AddOutput("MeanOut", AddOutput("MeanOut",
"Share memory with Mean. " "Share memory with Mean. "
"Store the global mean when training") "Store the global mean when training");
.Reuse("Mean");
AddOutput("VarianceOut", AddOutput("VarianceOut",
"Share memory with Variance. " "Share memory with Variance. "
"Store the global Variance when training") "Store the global Variance when training");
.Reuse("Variance");
AddOutput("SavedMean", AddOutput("SavedMean",
"Mean of the current mini batch, " "Mean of the current mini batch, "
"will apply to output when training") "will apply to output when training")
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/strided_memcpy.h"
namespace paddle { namespace paddle {
...@@ -89,29 +89,17 @@ class ConcatGradKernel : public framework::OpKernel<T> { ...@@ -89,29 +89,17 @@ class ConcatGradKernel : public framework::OpKernel<T> {
outputs.push_back(nullptr); outputs.push_back(nullptr);
} }
} }
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// Sometimes direct copies will be faster, this maybe need deeply analysis. // Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && outs.size() < 10) { if (axis == 0 && outs.size() < 10) {
size_t input_offset = 0; std::vector<const framework::Tensor*> ref_shape;
const auto in_stride = framework::stride_numel(out_grad->dims()); ref_shape.insert(ref_shape.begin(), ins.begin(), ins.end());
StridedMemcpyWithAxis0<T>(dev_ctx, *out_grad, ref_shape, &outputs);
for (size_t i = 0; i < outs.size(); ++i) {
auto out_stride = framework::stride_numel(ins[i]->dims());
auto* out = outputs[i];
if (out != nullptr) {
StridedNumelCopyWithAxis<T>(
ctx.device_context(), axis, out->data<T>(), out_stride,
out_grad->data<T>() + input_offset, in_stride, out_stride[axis]);
}
input_offset += out_stride[axis];
}
} else { } else {
auto& dev_ctx = ctx.template device_context<DeviceContext>(); math::SplitFunctor<DeviceContext, T> split_functor;
paddle::operators::math::ConcatGradFunctor<DeviceContext, T> split_functor(dev_ctx, *out_grad, ctx.MultiInput<framework::Tensor>("X"),
concat_grad_functor; static_cast<int>(axis), &outputs);
concat_grad_functor(dev_ctx, *out_grad,
ctx.MultiInput<framework::Tensor>("X"),
static_cast<int>(axis), &outputs);
} }
} }
}; };
......
...@@ -130,8 +130,7 @@ void Conv2DOpMaker::Make() { ...@@ -130,8 +130,7 @@ void Conv2DOpMaker::Make() {
.AsDispensable(); .AsDispensable();
AddOutput("Output", AddOutput("Output",
"(Tensor) The output tensor of convolution operator. " "(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.") "The format of output tensor is also NCHW.");
.Reuse("Input");
AddInput("ResidualData", AddInput("ResidualData",
"(Tensor) Tensor with residual data " "(Tensor) Tensor with residual data "
"to which convolution output will be added." "to which convolution output will be added."
...@@ -238,8 +237,7 @@ void Conv3DOpMaker::Make() { ...@@ -238,8 +237,7 @@ void Conv3DOpMaker::Make() {
"input image channels divided by the groups."); "input image channels divided by the groups.");
AddOutput("Output", AddOutput("Output",
"(Tensor) The output tensor of convolution operator." "(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW.") "The format of output tensor is also NCDHW.");
.Reuse("Input");
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector<int>, default:{1, 1, 1}), the " "(vector<int>, default:{1, 1, 1}), the "
"strides(d_stride, h_stride, w_stride) of " "strides(d_stride, h_stride, w_stride) of "
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/bbox_util.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
namespace paddle { namespace paddle {
......
...@@ -52,6 +52,9 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { ...@@ -52,6 +52,9 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE( PADDLE_ENFORCE(
ctx->HasOutput("TargetBBox"), ctx->HasOutput("TargetBBox"),
"Output(TargetBBox) of RpnTargetAssignOp should not be null"); "Output(TargetBBox) of RpnTargetAssignOp should not be null");
PADDLE_ENFORCE(
ctx->HasOutput("BBoxInsideWeight"),
"Output(BBoxInsideWeight) of RpnTargetAssignOp should not be null");
auto anchor_dims = ctx->GetInputDim("Anchor"); auto anchor_dims = ctx->GetInputDim("Anchor");
auto gt_boxes_dims = ctx->GetInputDim("GtBoxes"); auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
...@@ -68,6 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel { ...@@ -68,6 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ScoreIndex", {-1}); ctx->SetOutputDim("ScoreIndex", {-1});
ctx->SetOutputDim("TargetLabel", {-1, 1}); ctx->SetOutputDim("TargetLabel", {-1, 1});
ctx->SetOutputDim("TargetBBox", {-1, 4}); ctx->SetOutputDim("TargetBBox", {-1, 4});
ctx->SetOutputDim("BBoxInsideWeight", {-1, 4});
} }
protected: protected:
...@@ -169,6 +173,7 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, ...@@ -169,6 +173,7 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
const float rpn_positive_overlap, const float rpn_positive_overlap,
const float rpn_negative_overlap, std::vector<int>* fg_inds, const float rpn_negative_overlap, std::vector<int>* fg_inds,
std::vector<int>* bg_inds, std::vector<int>* tgt_lbl, std::vector<int>* bg_inds, std::vector<int>* tgt_lbl,
std::vector<int>* fg_fake, std::vector<T>* bbox_inside_weight,
std::minstd_rand engine, bool use_random) { std::minstd_rand engine, bool use_random) {
float epsilon = 0.00001; float epsilon = 0.00001;
int anchor_num = anchor_to_gt_max.dims()[0]; int anchor_num = anchor_to_gt_max.dims()[0];
...@@ -201,12 +206,12 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, ...@@ -201,12 +206,12 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
// Reservoir Sampling // Reservoir Sampling
int fg_num = static_cast<int>(rpn_fg_fraction * rpn_batch_size_per_im); int fg_num = static_cast<int>(rpn_fg_fraction * rpn_batch_size_per_im);
ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random); ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random);
fg_num = static_cast<int>(fg_inds_fake.size()); int fg_fake_num = static_cast<int>(fg_inds_fake.size());
for (int64_t i = 0; i < fg_num; ++i) { for (int64_t i = 0; i < fg_fake_num; ++i) {
target_label[fg_inds_fake[i]] = 1; target_label[fg_inds_fake[i]] = 1;
} }
int bg_num = rpn_batch_size_per_im - fg_num; int bg_num = rpn_batch_size_per_im - fg_fake_num;
for (int64_t i = 0; i < anchor_num; ++i) { for (int64_t i = 0; i < anchor_num; ++i) {
if (anchor_to_gt_max_data[i] < rpn_negative_overlap) { if (anchor_to_gt_max_data[i] < rpn_negative_overlap) {
bg_inds_fake.push_back(i); bg_inds_fake.push_back(i);
...@@ -214,12 +219,28 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data, ...@@ -214,12 +219,28 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
} }
ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random); ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random);
bg_num = static_cast<int>(bg_inds_fake.size()); bg_num = static_cast<int>(bg_inds_fake.size());
int fake_num = 0;
for (int64_t i = 0; i < bg_num; ++i) { for (int64_t i = 0; i < bg_num; ++i) {
// fg fake found
if (target_label[bg_inds_fake[i]] == 1) {
fake_num++;
fg_fake->emplace_back(fg_inds_fake[0]);
for (int j = 0; j < 4; ++j) {
bbox_inside_weight->emplace_back(T(0.));
}
}
target_label[bg_inds_fake[i]] = 0; target_label[bg_inds_fake[i]] = 0;
} }
for (int64_t i = 0; i < (fg_fake_num - fake_num) * 4; ++i) {
bbox_inside_weight->emplace_back(T(1.));
}
for (int64_t i = 0; i < anchor_num; ++i) { for (int64_t i = 0; i < anchor_num; ++i) {
if (target_label[i] == 1) fg_inds->emplace_back(i); if (target_label[i] == 1) {
fg_inds->emplace_back(i);
fg_fake->emplace_back(i);
}
if (target_label[i] == 0) bg_inds->emplace_back(i); if (target_label[i] == 0) bg_inds->emplace_back(i);
} }
fg_num = fg_inds->size(); fg_num = fg_inds->size();
...@@ -248,7 +269,8 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx, ...@@ -248,7 +269,8 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx,
std::vector<int> bg_inds; std::vector<int> bg_inds;
std::vector<int> gt_inds; std::vector<int> gt_inds;
std::vector<int> tgt_lbl; std::vector<int> tgt_lbl;
std::vector<int> fg_fake;
std::vector<T> bbox_inside_weight;
// Calculate the max IoU between anchors and gt boxes // Calculate the max IoU between anchors and gt boxes
// Map from anchor to gt box that has highest overlap // Map from anchor to gt box that has highest overlap
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
...@@ -275,32 +297,37 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx, ...@@ -275,32 +297,37 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx,
// Follow the Faster RCNN's implementation // Follow the Faster RCNN's implementation
ScoreAssign(anchor_by_gt_overlap_data, anchor_to_gt_max, gt_to_anchor_max, ScoreAssign(anchor_by_gt_overlap_data, anchor_to_gt_max, gt_to_anchor_max,
rpn_batch_size_per_im, rpn_fg_fraction, rpn_positive_overlap, rpn_batch_size_per_im, rpn_fg_fraction, rpn_positive_overlap,
rpn_negative_overlap, &fg_inds, &bg_inds, &tgt_lbl, engine, rpn_negative_overlap, &fg_inds, &bg_inds, &tgt_lbl, &fg_fake,
use_random); &bbox_inside_weight, engine, use_random);
int fg_num = fg_inds.size(); int fg_num = fg_inds.size();
int bg_num = bg_inds.size(); int bg_num = bg_inds.size();
gt_inds.reserve(fg_num); int fg_fake_num = fg_fake.size();
for (int i = 0; i < fg_num; ++i) { gt_inds.reserve(fg_fake_num);
gt_inds.emplace_back(argmax[fg_inds[i]]); for (int i = 0; i < fg_fake_num; ++i) {
gt_inds.emplace_back(argmax[fg_fake[i]]);
} }
Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t, bbox_inside_weight_t;
Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t; int* loc_index_data = loc_index_t.mutable_data<int>({fg_fake_num}, place);
int* loc_index_data = loc_index_t.mutable_data<int>({fg_num}, place);
int* score_index_data = int* score_index_data =
score_index_t.mutable_data<int>({fg_num + bg_num}, place); score_index_t.mutable_data<int>({fg_num + bg_num}, place);
int* tgt_lbl_data = tgt_lbl_t.mutable_data<int>({fg_num + bg_num}, place); int* tgt_lbl_data = tgt_lbl_t.mutable_data<int>({fg_num + bg_num}, place);
int* gt_inds_data = gt_inds_t.mutable_data<int>({fg_num}, place); int* gt_inds_data = gt_inds_t.mutable_data<int>({fg_fake_num}, place);
std::copy(fg_inds.begin(), fg_inds.end(), loc_index_data); T* bbox_inside_weight_data =
bbox_inside_weight_t.mutable_data<T>({fg_fake_num, 4}, place);
std::copy(fg_fake.begin(), fg_fake.end(), loc_index_data);
std::copy(fg_inds.begin(), fg_inds.end(), score_index_data); std::copy(fg_inds.begin(), fg_inds.end(), score_index_data);
std::copy(bg_inds.begin(), bg_inds.end(), score_index_data + fg_num); std::copy(bg_inds.begin(), bg_inds.end(), score_index_data + fg_num);
std::copy(tgt_lbl.begin(), tgt_lbl.end(), tgt_lbl_data); std::copy(tgt_lbl.begin(), tgt_lbl.end(), tgt_lbl_data);
std::copy(gt_inds.begin(), gt_inds.end(), gt_inds_data); std::copy(gt_inds.begin(), gt_inds.end(), gt_inds_data);
std::copy(bbox_inside_weight.begin(), bbox_inside_weight.end(),
bbox_inside_weight_data);
std::vector<Tensor> loc_score_tgtlbl_gt; std::vector<Tensor> loc_score_tgtlbl_gt;
loc_score_tgtlbl_gt.emplace_back(loc_index_t); loc_score_tgtlbl_gt.emplace_back(loc_index_t);
loc_score_tgtlbl_gt.emplace_back(score_index_t); loc_score_tgtlbl_gt.emplace_back(score_index_t);
loc_score_tgtlbl_gt.emplace_back(tgt_lbl_t); loc_score_tgtlbl_gt.emplace_back(tgt_lbl_t);
loc_score_tgtlbl_gt.emplace_back(gt_inds_t); loc_score_tgtlbl_gt.emplace_back(gt_inds_t);
loc_score_tgtlbl_gt.emplace_back(bbox_inside_weight_t);
return loc_score_tgtlbl_gt; return loc_score_tgtlbl_gt;
} }
...@@ -318,6 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> { ...@@ -318,6 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
auto* score_index = context.Output<LoDTensor>("ScoreIndex"); auto* score_index = context.Output<LoDTensor>("ScoreIndex");
auto* tgt_bbox = context.Output<LoDTensor>("TargetBBox"); auto* tgt_bbox = context.Output<LoDTensor>("TargetBBox");
auto* tgt_lbl = context.Output<LoDTensor>("TargetLabel"); auto* tgt_lbl = context.Output<LoDTensor>("TargetLabel");
auto* bbox_inside_weight = context.Output<LoDTensor>("BBoxInsideWeight");
PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL, PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL,
"RpnTargetAssignOp gt_boxes needs 1 level of LoD"); "RpnTargetAssignOp gt_boxes needs 1 level of LoD");
...@@ -340,7 +368,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> { ...@@ -340,7 +368,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
score_index->mutable_data<int>({max_num}, place); score_index->mutable_data<int>({max_num}, place);
tgt_bbox->mutable_data<T>({max_num, 4}, place); tgt_bbox->mutable_data<T>({max_num, 4}, place);
tgt_lbl->mutable_data<int>({max_num, 1}, place); tgt_lbl->mutable_data<int>({max_num, 1}, place);
bbox_inside_weight->mutable_data<T>({max_num, 4}, place);
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>(); auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
std::random_device rnd; std::random_device rnd;
...@@ -394,6 +422,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> { ...@@ -394,6 +422,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
Tensor sampled_score_index = loc_score_tgtlbl_gt[1]; Tensor sampled_score_index = loc_score_tgtlbl_gt[1];
Tensor sampled_tgtlbl = loc_score_tgtlbl_gt[2]; Tensor sampled_tgtlbl = loc_score_tgtlbl_gt[2];
Tensor sampled_gt_index = loc_score_tgtlbl_gt[3]; Tensor sampled_gt_index = loc_score_tgtlbl_gt[3];
Tensor sampled_bbox_inside_weight = loc_score_tgtlbl_gt[4];
int loc_num = sampled_loc_index.dims()[0]; int loc_num = sampled_loc_index.dims()[0];
int score_num = sampled_score_index.dims()[0]; int score_num = sampled_score_index.dims()[0];
...@@ -432,6 +461,8 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> { ...@@ -432,6 +461,8 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
AppendRpns<int>(score_index, total_score_num, &sampled_score_index_unmap); AppendRpns<int>(score_index, total_score_num, &sampled_score_index_unmap);
AppendRpns<T>(tgt_bbox, total_loc_num * 4, &sampled_tgt_bbox); AppendRpns<T>(tgt_bbox, total_loc_num * 4, &sampled_tgt_bbox);
AppendRpns<int>(tgt_lbl, total_score_num, &sampled_tgtlbl); AppendRpns<int>(tgt_lbl, total_score_num, &sampled_tgtlbl);
AppendRpns<T>(bbox_inside_weight, total_loc_num * 4,
&sampled_bbox_inside_weight);
total_loc_num += loc_num; total_loc_num += loc_num;
total_score_num += score_num; total_score_num += score_num;
...@@ -448,10 +479,12 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> { ...@@ -448,10 +479,12 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
score_index->set_lod(loc_score); score_index->set_lod(loc_score);
tgt_bbox->set_lod(lod_loc); tgt_bbox->set_lod(lod_loc);
tgt_lbl->set_lod(loc_score); tgt_lbl->set_lod(loc_score);
bbox_inside_weight->set_lod(lod_loc);
loc_index->Resize({total_loc_num}); loc_index->Resize({total_loc_num});
score_index->Resize({total_score_num}); score_index->Resize({total_score_num});
tgt_bbox->Resize({total_loc_num, 4}); tgt_bbox->Resize({total_loc_num, 4});
tgt_lbl->Resize({total_score_num, 1}); tgt_lbl->Resize({total_score_num, 1});
bbox_inside_weight->Resize({total_loc_num, 4});
} }
}; };
...@@ -514,6 +547,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -514,6 +547,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
"TargetLabel", "TargetLabel",
"(Tensor<int>), The target labels of each anchor with shape " "(Tensor<int>), The target labels of each anchor with shape "
"[F + B, 1], F and B are sampled foreground and backgroud number."); "[F + B, 1], F and B are sampled foreground and backgroud number.");
AddOutput("BBoxInsideWeight",
"(Tensor), The bbox inside weight with shape "
"[F, 4], F is the sampled foreground number.");
AddComment(R"DOC( AddComment(R"DOC(
This operator can be, for a given set of ground truth bboxes and the This operator can be, for a given set of ground truth bboxes and the
anchors, to assign classification and regression targets to each prediction. anchors, to assign classification and regression targets to each prediction.
......
...@@ -80,8 +80,6 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -80,8 +80,6 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() final { void Make() final {
AddInput("X", "(Tensor), The first input tensor of elementwise op."); AddInput("X", "(Tensor), The first input tensor of elementwise op.");
AddInput("Y", "(Tensor), The second input tensor of elementwise op."); AddInput("Y", "(Tensor), The second input tensor of elementwise op.");
// AddOutput("SavedShape", "(Tensor), save X, Y shape for grad to save
// memory.").AsIntermediate();
AddOutput("Out", "The output of elementwise op."); AddOutput("Out", "The output of elementwise op.");
AddAttr<int>("axis", AddAttr<int>("axis",
"(int, default -1). The start dimension index " "(int, default -1). The start dimension index "
...@@ -129,13 +127,11 @@ But the output only shares the LoD information with the input $X$. ...@@ -129,13 +127,11 @@ But the output only shares the LoD information with the input $X$.
)DOC", )DOC",
GetName(), GetEquation())); GetName(), GetEquation()));
SetReuse();
} }
protected: protected:
virtual std::string GetName() const = 0; virtual std::string GetName() const = 0;
virtual std::string GetEquation() const = 0; virtual std::string GetEquation() const = 0;
virtual void SetReuse() {}
}; };
class ElementwiseOpGrad : public framework::OperatorWithKernel { class ElementwiseOpGrad : public framework::OperatorWithKernel {
...@@ -269,7 +265,6 @@ class ElemwiseGradKernel : public framework::OpKernel<T> { ...@@ -269,7 +265,6 @@ class ElemwiseGradKernel : public framework::OpKernel<T> {
protected: \ protected: \
virtual std::string GetName() const { return op_name; } \ virtual std::string GetName() const { return op_name; } \
virtual std::string GetEquation() const { return equation; } \ virtual std::string GetEquation() const { return equation; } \
virtual void SetReuse() { Reuse(__VA_ARGS__); } \
}; \ }; \
REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \ REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp, \
__ElemwiseOp##op_type##Maker__, \ __ElemwiseOp##op_type##Maker__, \
......
...@@ -16,10 +16,9 @@ limitations under the License. */ ...@@ -16,10 +16,9 @@ limitations under the License. */
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> {
} }
} }
#define INIT_VEC_FUNC \ #define INIT_BASE_DEFINES \
std::function<void(const int, const T *, T *)> act_gate, act_state; \ auto* x = ctx.Input<LoDTensor>("X"); \
std::function<void(const int, const T*, const T*, const T*, T*)> cross; \ auto* wh = ctx.Input<Tensor>("WeightH"); \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \ auto* xx = ctx.Output<LoDTensor>("XX"); \
auto& act_state_str = ctx.Attr<std::string>("activation"); \ auto x_lod = x->lod(); \
if (platform::jit::MayIUse(platform::jit::avx)) { \ auto x_dims = x->dims(); /* T x M*/ \
math::VecActivations<T, platform::jit::avx> act_functor; \ auto wh_dims = wh->dims(); /* D x 3D*/ \
act_gate = act_functor(act_gate_str); \ const int total_T = x_dims[0]; \
act_state = act_functor(act_state_str); \ const int D3 = wh_dims[1]
cross = math::vec_cross<T, platform::jit::avx>; \
} else { \ #define INIT_OTHER_DEFINES \
math::VecActivations<T, platform::jit::isa_any> act_functor; \ auto* h0 = ctx.Input<Tensor>("H0"); \
act_gate = act_functor(act_gate_str); \ auto* wx = ctx.Input<Tensor>("WeightX"); \
act_state = act_functor(act_state_str); \ auto* bias = ctx.Input<Tensor>("Bias"); \
cross = math::vec_cross<T, platform::jit::isa_any>; \ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
} bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
#define INIT_BASE_INPUT_OUTPUT \ const int D = wh_dims[0]; \
auto* h0 = ctx.Input<Tensor>("H0"); \ const int D2 = D * 2; \
auto* wx = ctx.Input<Tensor>("WeightX"); \ const auto& ker = math::jitkernel::KernelPool::Instance() \
auto* wh = ctx.Input<Tensor>("WeightH"); \ .template Get<math::jitkernel::GRUKernel<T>, \
auto* bias = ctx.Input<Tensor>("Bias"); \ const std::string&, const std::string&>( \
auto* xx = ctx.Output<LoDTensor>("XX"); \ ctx.Attr<std::string>("gate_activation"), \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ ctx.Attr<std::string>("activation"), D); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
#define INIT_BASE_SIZES \ const T* wh_data = wh->data<T>(); \
auto x_dims = x->dims(); /* T x M*/ \ auto place = ctx.GetPlace(); \
auto wh_dims = wh->dims(); /* D x 3D*/ \ T* xx_data = xx->mutable_data<T>(place)
const int total_T = x_dims[0]; \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D3 = wh_dims[1]; \
const int D2 = D * 2;
void SeqCompute(const framework::ExecutionContext& ctx) const { void SeqCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); INIT_BASE_DEFINES;
INIT_BASE_INPUT_OUTPUT INIT_OTHER_DEFINES;
INIT_BASE_SIZES
INIT_VEC_FUNC
auto x_lod = x->lod();
const int N = x_lod[0].size() - 1; const int N = x_lod[0].size() - 1;
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : nullptr; const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
const T* wh_state_data = wh_data + D * D2; const T* wh_state_data = wh_data + D * D2;
T* xx_data = xx->mutable_data<T>(ctx.GetPlace()); T* hidden_out_data = hidden_out->mutable_data<T>(place);
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data, math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
xx_data, xx_data,
...@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
if (h0_data) { if (h0_data) {
prev_hidden_data = h0_data + bid * D; prev_hidden_data = h0_data + bid * D;
} else { } else {
// W: {W_update, W_reset; W_state} ker->ComputeH1(xx_data, hidden_out_data);
// update gate
act_gate(D, xx_data, xx_data);
// state gate
act_state(D, xx_data + D2, xx_data + D2);
// out = a*b
blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data);
// save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
tstart = 1; tstart = 1;
move_step(); move_step();
...@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data, prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
D3); D3);
act_gate(D2, xx_data, xx_data); ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data);
// rt = rt*ht_1 inplace result
blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data);
// gemm rt * Ws // gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1), hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3); xx_data + D2, D3);
act_state(D, xx_data + D2, xx_data + D2); ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data);
// save prev // save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
move_step(); move_step();
...@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const { void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); INIT_BASE_DEFINES;
INIT_BASE_INPUT_OUTPUT if (x_lod[0].size() == 2) {
INIT_BASE_SIZES
if (x->lod()[0].size() == 2) {
xx->Resize({total_T, D3}); xx->Resize({total_T, D3});
SeqCompute(ctx); SeqCompute(ctx);
return; return;
} }
INIT_VEC_FUNC INIT_OTHER_DEFINES;
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0"); auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput"); auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
auto* batched_out = ctx.Output<LoDTensor>("BatchedOut"); auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
T* batched_input_data = batched_input->mutable_data<T>(place);
const T* x_data = x->data<T>(); T* batched_out_data = batched_out->mutable_data<T>(place);
const T* wx_data = wx->data<T>(); hidden_out->mutable_data<T>(place);
const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_input_data = batched_input->mutable_data<T>(ctx.GetPlace());
T* batched_out_data = batched_out->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx); auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch; math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
...@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* prev_hidden_data = nullptr; T* prev_hidden_data = nullptr;
if (h0) { if (h0) {
// reorder h0 // reorder h0
T* reordered_h0_data = reordered_h0->mutable_data<T>(ctx.GetPlace()); T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
const T* h0_data = h0->data<T>(); const T* h0_data = h0->data<T>();
prev_hidden_data = reordered_h0_data; prev_hidden_data = reordered_h0_data;
size_t sz = sizeof(T) * D; size_t sz = sizeof(T) * D;
...@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
// update gate ker->ComputeH1(cur_in_data, cur_out_data);
act_gate(D, cur_in_data, cur_in_data);
// state gate
act_state(D, cur_in_data + D2, cur_in_data + D2);
// out = a*b
blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data);
// add offset // add offset
cur_in_data += D3; cur_in_data += D3;
cur_out_data += D; cur_out_data += D;
...@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
T* cur_prev_hidden_data = prev_hidden_data; T* cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
act_gate(D2, cur_batched_data, cur_batched_data); ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data,
// rt = rt*ht_1 inplace result cur_out_data);
blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
cur_prev_hidden_data = prev_hidden_data; cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
// ht~ = act_state(...) ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data,
act_state(D, cur_batched_data + D2, cur_batched_data + D2); cur_out_data);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data,
cur_out_data);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
batched_out->set_lod(batched_lod); batched_out->set_lod(batched_lod);
to_seq(dev_ctx, *batched_out, hidden_out); to_seq(dev_ctx, *batched_out, hidden_out);
} }
#undef INIT_VEC_FUNC #undef INIT_OTHER_DEFINES
#undef INIT_BASE_SIZES #undef INIT_BASE_DEFINES
#undef INIT_BASE_INPUT_OUTPUT
}; };
} // namespace operators } // namespace operators
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
...@@ -79,7 +79,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> { ...@@ -79,7 +79,7 @@ struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
template <typename DeviceContext> template <typename DeviceContext>
template <typename T> template <typename T>
void LoDTensorToArrayFunctorImpl<DeviceContext>::apply() { void LoDTensorToArrayFunctorImpl<DeviceContext>::apply() {
math::ConcatGradFunctor<DeviceContext, T> func; math::SplitFunctor<DeviceContext, T> func;
func(*dev_ctx_, prev_functor_->input_, prev_functor_->ref_inputs_, 0, func(*dev_ctx_, prev_functor_->input_, prev_functor_->ref_inputs_, 0,
&prev_functor_->outputs_); &prev_functor_->outputs_);
} }
......
if (NOT WIN32) if (NOT WIN32)
add_subdirectory(detail) add_subdirectory(detail)
endif(NOT WIN32) endif(NOT WIN32)
function(math_library TARGET) function(math_library TARGET)
...@@ -35,7 +35,7 @@ function(math_library TARGET) ...@@ -35,7 +35,7 @@ function(math_library TARGET)
endfunction() endfunction()
# please add new math_library in alphabetical order # please add new math_library in alphabetical order
math_library(concat) math_library(concat_and_split)
math_library(context_project DEPS im2col math_function) math_library(context_project DEPS im2col math_function)
math_library(cross_entropy) math_library(cross_entropy)
math_library(cos_sim_functor) math_library(cos_sim_functor)
...@@ -43,8 +43,8 @@ math_library(depthwise_conv) ...@@ -43,8 +43,8 @@ math_library(depthwise_conv)
math_library(im2col) math_library(im2col)
if (NOT WIN32) # windows do not support avx functions yet. if (NOT WIN32) # windows do not support avx functions yet.
math_library(gru_compute DEPS activation_functions math_function) math_library(gru_compute DEPS activation_functions math_function)
math_library(lstm_compute DEPS activation_functions) math_library(lstm_compute DEPS activation_functions)
endif (NOT WIN32) endif (NOT WIN32)
cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context) cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context)
...@@ -58,7 +58,7 @@ math_library(sequence_pooling DEPS math_function) ...@@ -58,7 +58,7 @@ math_library(sequence_pooling DEPS math_function)
math_library(sequence_scale) math_library(sequence_scale)
math_library(softmax DEPS math_function) math_library(softmax DEPS math_function)
if (NOT WIN32) if (NOT WIN32)
math_library(matrix_bit_code) math_library(matrix_bit_code)
endif (NOT WIN32) endif (NOT WIN32)
math_library(unpooling) math_library(unpooling)
math_library(vol2col) math_library(vol2col)
...@@ -68,13 +68,14 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec ...@@ -68,13 +68,14 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec
cc_test(im2col_test SRCS im2col_test.cc DEPS im2col) cc_test(im2col_test SRCS im2col_test.cc DEPS im2col)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col) cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col)
cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding) cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
cc_test(sequence_pooling_test SRCS sequence_pooling_test.cc DEPS sequence_pooling)
if(WITH_GPU) if(WITH_GPU)
nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function) nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function)
nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function) nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function)
endif() endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat) cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel cc_library(jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc
DEPS cpu_info cblas) DEPS cpu_info cblas)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include <vector> #include <vector>
namespace paddle { namespace paddle {
...@@ -67,7 +67,7 @@ class ConcatFunctor<platform::CPUDeviceContext, T> { ...@@ -67,7 +67,7 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
* each dimension must be the same, except the axis dimension. * each dimension must be the same, except the axis dimension.
*/ */
template <typename T> template <typename T>
class ConcatGradFunctor<platform::CPUDeviceContext, T> { class SplitFunctor<platform::CPUDeviceContext, T> {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
...@@ -111,7 +111,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> { ...@@ -111,7 +111,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
}; };
#define DEFINE_FUNCTOR(type) \ #define DEFINE_FUNCTOR(type) \
template class ConcatFunctor<platform::CPUDeviceContext, type>; \ template class ConcatFunctor<platform::CPUDeviceContext, type>; \
template class ConcatGradFunctor<platform::CPUDeviceContext, type>; template class SplitFunctor<platform::CPUDeviceContext, type>;
FOR_ALL_TYPES(DEFINE_FUNCTOR); FOR_ALL_TYPES(DEFINE_FUNCTOR);
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
...@@ -24,7 +24,7 @@ namespace operators { ...@@ -24,7 +24,7 @@ namespace operators {
namespace math { namespace math {
template <typename T> template <typename T>
__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size, __global__ void ConcatKernel(T** inputs, const int* input_cols, int col_size,
const int output_rows, const int output_cols, const int output_rows, const int output_cols,
T* output) { T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
...@@ -50,7 +50,7 @@ __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size, ...@@ -50,7 +50,7 @@ __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
} }
template <typename T> template <typename T>
__global__ void KernelConcat(T** inputs_data, const int fixed_in_col, __global__ void ConcatKernel(T** inputs_data, const int fixed_in_col,
const int out_rows, const int out_cols, const int out_rows, const int out_cols,
T* output_data) { T* output_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
...@@ -67,9 +67,9 @@ __global__ void KernelConcat(T** inputs_data, const int fixed_in_col, ...@@ -67,9 +67,9 @@ __global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
} }
template <typename T> template <typename T>
__global__ void KernelConcatGrad(const T* input_data, const int in_row, __global__ void SplitKernel(const T* input_data, const int in_row,
const int in_col, const int* out_cols, const int in_col, const int* out_cols,
int out_cols_size, T** outputs_data) { int out_cols_size, T** outputs_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int curr_segment = 0; int curr_segment = 0;
int curr_offset = out_cols[0]; int curr_offset = out_cols[0];
...@@ -94,9 +94,9 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row, ...@@ -94,9 +94,9 @@ __global__ void KernelConcatGrad(const T* input_data, const int in_row,
} }
template <typename T> template <typename T>
__global__ void KernelConcatGrad(const T* input_data, const int in_row, __global__ void SplitKernel(const T* input_data, const int in_row,
const int in_col, const int fixed_out_col, const int in_col, const int fixed_out_col,
T** outputs_data) { T** outputs_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) { for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
int split = tid_x / fixed_out_col; int split = tid_x / fixed_out_col;
...@@ -170,11 +170,11 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { ...@@ -170,11 +170,11 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
dim3 grid_size = dim3(grid_cols, grid_rows, 1); dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (sameShape) { if (sameShape) {
KernelConcat<<<grid_size, block_size, 0, context.stream()>>>( ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
dev_ins_data, in_col, out_row, out_col, output->data<T>()); dev_ins_data, in_col, out_row, out_col, output->data<T>());
} else { } else {
const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace()); const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
KernelConcat<<<grid_size, block_size, 0, context.stream()>>>( ConcatKernel<<<grid_size, block_size, 0, context.stream()>>>(
dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()), dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
out_row, out_col, output->data<T>()); out_row, out_col, output->data<T>());
} }
...@@ -189,7 +189,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { ...@@ -189,7 +189,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
* each dimension must be the same, except the axis dimension. * each dimension must be the same, except the axis dimension.
*/ */
template <typename T> template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> { class SplitFunctor<platform::CUDADeviceContext, T> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
...@@ -248,11 +248,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { ...@@ -248,11 +248,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
dim3 grid_size = dim3(grid_cols, grid_rows, 1); dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (sameShape) { if (sameShape) {
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>( SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data); input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
} else { } else {
const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace()); const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>( SplitKernel<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), in_row, in_col, dev_outs_col_data, input.data<T>(), in_row, in_col, dev_outs_col_data,
static_cast<int>(outputs_cols.size()), dev_out_gpu_data); static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
} }
...@@ -264,7 +264,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { ...@@ -264,7 +264,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
#define DEFINE_FUNCTOR(type) \ #define DEFINE_FUNCTOR(type) \
template class ConcatFunctor<platform::CUDADeviceContext, type>; \ template class ConcatFunctor<platform::CUDADeviceContext, type>; \
template class ConcatGradFunctor<platform::CUDADeviceContext, type> template class SplitFunctor<platform::CUDADeviceContext, type>
FOR_ALL_TYPES(DEFINE_FUNCTOR); FOR_ALL_TYPES(DEFINE_FUNCTOR);
......
...@@ -54,7 +54,7 @@ class ConcatFunctor { ...@@ -54,7 +54,7 @@ class ConcatFunctor {
* Output[1] = [[5,6]] * Output[1] = [[5,6]]
*/ */
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ConcatGradFunctor { class SplitFunctor {
public: public:
void operator()(const DeviceContext& context, const framework::Tensor& input, void operator()(const DeviceContext& context, const framework::Tensor& input,
const std::vector<const framework::Tensor*>& ref_inputs, const std::vector<const framework::Tensor*>& ref_inputs,
......
...@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/concat.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
template <typename DeviceContext, typename Place> template <typename DeviceContext, typename Place>
void testConcat() { void testConcat() {
......
...@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel { ...@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel {
const T *wp_data = nullptr) const = 0; const T *wp_data = nullptr) const = 0;
}; };
template <typename T>
class GRUKernel : public Kernel {
public:
// compute h1 without h0
virtual void ComputeH1(T *gates, T *ht) const = 0;
virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0;
virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0;
};
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -136,6 +136,23 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel( ...@@ -136,6 +136,23 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
return nullptr; return nullptr;
} }
#ifdef __AVX__
template <jit::cpu_isa_t isa>
static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
if (type == "sigmoid") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>());
} else if (type == "relu") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>());
} else if (type == "tanh") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>());
} else if (type == "identity" || type == "") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>());
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
#endif
/* LSTM JitKernel */ /* LSTM JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T, jit::cpu_isa_t isa, jit_block>
class LSTMKernelImpl : public LSTMKernel<T> { class LSTMKernelImpl : public LSTMKernel<T> {
...@@ -192,61 +209,49 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -192,61 +209,49 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif #endif
}; };
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \ LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
const std::string& act_gate, const std::string& act_cand, \ const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \ const std::string& act_cell, int d) \
: LSTMKernel<float>() { \ : LSTMKernel<float>() { \
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \ avx_act_gate_ = GetAVXAct<isa>(act_gate); \
if (type == "sigmoid") { \ avx_act_cand_ = GetAVXAct<isa>(act_cand); \
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \ avx_act_cell_ = GetAVXAct<isa>(act_cell); \
} else if (type == "relu") { \ } \
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \ template <> \
} else if (type == "tanh") { \ void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \ float* gates, const float* ct_1, float* ct, float* ht, \
} else if (type == "identity" || type == "") { \ const float* wp_data, float* checked) const { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \ /* gates: W_ch, W_ih, W_fh, W_oh */ \
} \ __m256 c, i, f, o; \
PADDLE_THROW("Not support type: %s", type); \ c = _mm256_loadu_ps(gates); \
}; \ i = _mm256_loadu_ps(gates + 8); \
avx_act_gate_ = GetAVXAct(act_gate); \ f = _mm256_loadu_ps(gates + 16); \
avx_act_cand_ = GetAVXAct(act_cand); \ o = _mm256_loadu_ps(gates + 24); \
avx_act_cell_ = GetAVXAct(act_cell); \ /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
} \ c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
template <> \ i = _mm256_loadu_ps(ct_1); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \ f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
float* gates, const float* ct_1, float* ct, float* ht, \ f = _mm256_add_ps(c, f); \
const float* wp_data, float* checked) const { \ _mm256_storeu_ps(ct, f); \
/* gates: W_ch, W_ih, W_fh, W_oh */ \ /* H_t = act_cell(C_t) * ogated */ \
__m256 c, i, f, o; \ o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
c = _mm256_loadu_ps(gates); \ _mm256_storeu_ps(ht, o); \
i = _mm256_loadu_ps(gates + 8); \ } \
f = _mm256_loadu_ps(gates + 16); \ template <> \
o = _mm256_loadu_ps(gates + 24); \ void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \ float* gates, float* ct, float* ht, const float* wp_data) const { \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ __m256 c, i, o; \
i = _mm256_loadu_ps(ct_1); \ c = _mm256_loadu_ps(gates); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ i = _mm256_loadu_ps(gates + 8); \
f = _mm256_add_ps(c, f); \ o = _mm256_loadu_ps(gates + 24); \
_mm256_storeu_ps(ct, f); \ /* C_t = igated * cgated*/ \
/* H_t = act_cell(C_t) * ogated */ \ c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ _mm256_storeu_ps(ct, c); \
_mm256_storeu_ps(ht, o); \ /* H_t = act_cell(C_t) * ogated */ \
} \ o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
template <> \ _mm256_storeu_ps(ht, o); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/ \
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} }
// TODO(TJ): optimize keq16 // TODO(TJ): optimize keq16
...@@ -354,6 +359,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, ...@@ -354,6 +359,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
#undef JITKERNEL_DECLARE_LSTM #undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_KEY_LSTM #undef JITKERNEL_KEY_LSTM
#undef JITKERNEL_NEW_LSTM_IMPL #undef JITKERNEL_NEW_LSTM_IMPL
/* GRU JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class GRUKernelImpl : public GRUKernel<T> {
public:
explicit GRUKernelImpl(const std::string& act_gate,
const std::string& act_state, int d)
: GRUKernel<T>() {
d_ = d;
d2_ = d * 2;
act_gate_d2_ = GetActKernel<T>(act_gate, d2_);
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_state_d_ = GetActKernel<T>(act_state, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
}
void ComputeH1(T* gates, T* ht) const override {
act_gate_d_->Compute(gates, gates);
act_state_d_->Compute(gates + d2_, gates + d2_);
vmul_d_->Compute(gates, gates + d2_, ht);
}
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
// W: {W_update, W_reset; W_state}
act_gate_d2_->Compute(gates, gates);
vmul_d_->Compute(ht_1, gates + d_, ht);
}
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
T* y = gates + d2_;
act_state_d_->Compute(y, y);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d_; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
}
}
private:
int d_, d2_;
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
#ifdef __AVX__
std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_state_;
#endif
};
#define INTRI8_FLOAT(isa) \
template <> \
GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl( \
const std::string& act_gate, const std::string& act_state, int d) \
: GRUKernel<float>() { \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
avx_act_state_ = GetAVXAct<isa>(act_state); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht) \
const { \
__m256 u, s; \
/* W: {W_update, W_reset; W_state} */ \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \
_mm256_storeu_ps(ht, s); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */ \
__m256 r, ht0; \
r = _mm256_loadu_ps(gates + 8); \
ht0 = _mm256_loadu_ps(ht_1); \
r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \
_mm256_storeu_ps(ht, r); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */ \
__m256 u, s, ht0; \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
ht0 = _mm256_loadu_ps(ht_1); \
u = avx_act_gate_->Compute(u); \
s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \
u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \
u = _mm256_mul_ps(u, ht0); \
u = _mm256_add_ps(s, u); \
_mm256_storeu_ps(ht, u); \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#endif
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \
GRUKernel<ker_dtype>, const std::string&, const std::string&, int>( \
const std::string& act_gate, const std::string& act_state, int d)
#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_state
#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d));
REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);
#undef INTRI8_FLOAT
#undef JITKERNEL_NEW_GRU_IMPL
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -157,6 +157,31 @@ class FirstSeqPoolFunctor { ...@@ -157,6 +157,31 @@ class FirstSeqPoolFunctor {
} }
}; };
template <typename T>
class SumSeqPoolGradFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& out_grad,
framework::LoDTensor* in_grad) {
auto lod = in_grad->lod()[0];
int64_t out_w = out_grad.numel() / out_grad.dims()[0];
int64_t in_w = in_grad->numel() / in_grad->dims()[0];
PADDLE_ENFORCE(in_w == out_w);
const T* out_g_data = out_grad.data<T>();
T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
int64_t in_offset = lod[i] * in_w;
const T* out_pos = out_g_data + i * out_w;
T* in_pos = in_g_data + in_offset;
for (int r = 0; r != h; ++r) {
blas.VCOPY(in_w, out_pos, in_pos + r * in_w);
}
}
}
};
template <typename T> template <typename T>
class SequencePoolFunctor<platform::CPUDeviceContext, T> { class SequencePoolFunctor<platform::CPUDeviceContext, T> {
public: public:
...@@ -231,9 +256,15 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> { ...@@ -231,9 +256,15 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
math::SetConstant<platform::CPUDeviceContext, T> functor; math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(context, in_grad, 0); functor(context, in_grad, 0);
} }
if (pooltype == "SUM") {
math::SumSeqPoolGradFunctor<T> sum_pool_grad;
sum_pool_grad(context, out_grad, in_grad);
return;
}
auto lod = in_grad->lod()[0]; auto lod = in_grad->lod()[0];
auto& place = *context.eigen_device(); auto& place = *context.eigen_device();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]), auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
static_cast<int>(lod[i + 1])); static_cast<int>(lod[i + 1]));
...@@ -247,12 +278,6 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> { ...@@ -247,12 +278,6 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
if (pooltype == "AVERAGE") { if (pooltype == "AVERAGE") {
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast); in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
} else if (pooltype == "SUM") {
const T* out_g_data = out_g_t.data<T>();
T* in_g_data = in_g_t.mutable_data<T>(context.GetPlace());
for (int r = 0; r != h; ++r) {
blas.VCOPY(w, out_g_data, in_g_data + r * w);
}
} else if (pooltype == "SQRT") { } else if (pooltype == "SQRT") {
in_g_e.device(place) = in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast); (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sequence_pooling.h"
#include <gtest/gtest.h>
#include <vector>
template <typename DeviceContext, typename Place, typename T>
void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
paddle::framework::LoDTensor cpu_out_grad;
paddle::framework::LoDTensor cpu_in_grad;
paddle::framework::LoDTensor out_grad;
paddle::framework::LoDTensor in_grad;
const size_t second_dim = 128u;
// construct out_grad's tensor in cpu
const size_t out_first_dim = lod[0].size() - 1;
auto out_dims = paddle::framework::make_ddim(
{static_cast<int64_t>(out_first_dim), static_cast<int64_t>(second_dim)});
cpu_out_grad.mutable_data<T>(out_dims, paddle::platform::CPUPlace());
for (int64_t i = 0; i < cpu_out_grad.numel(); ++i) {
cpu_out_grad.data<T>()[i] = static_cast<T>(i);
}
// copy to dst out_grad
auto* place = new Place();
DeviceContext* context = new DeviceContext(*place);
if (paddle::platform::is_cpu_place(*place)) {
out_grad = cpu_out_grad;
} else {
TensorCopySync(cpu_out_grad, *place, &out_grad);
}
// construct in_grad
in_grad.set_lod(lod);
auto in_dims = paddle::framework::make_ddim(
{static_cast<int64_t>(lod[0].back()), static_cast<int64_t>(second_dim)});
in_grad.mutable_data<T>(in_dims, context->GetPlace());
// check tensor contruction result
PADDLE_ENFORCE_EQ(in_grad.dims().size(), out_grad.dims().size());
for (int64_t i = 1; i < out_grad.dims().size(); ++i) {
PADDLE_ENFORCE_EQ(in_grad.dims()[i], out_grad.dims()[i]);
}
// call functor
paddle::operators::math::SequencePoolGradFunctor<DeviceContext, T>()(
*context, "SUM", out_grad, &in_grad);
if (paddle::platform::is_cpu_place(*place)) {
cpu_in_grad = in_grad;
} else {
TensorCopySync(in_grad, paddle::platform::CPUPlace(), &cpu_in_grad);
cpu_in_grad.set_lod(in_grad.lod());
}
EXPECT_EQ(in_grad.numel(), lod[0].back() * second_dim);
EXPECT_EQ(in_grad.lod(), lod);
if (paddle::platform::is_cpu_place(*place)) {
for (int64_t i = 0; i < in_grad.lod()[0].size() - 1; ++i) {
int64_t begin = in_grad.lod()[0][i];
int64_t end = in_grad.lod()[0][i + 1];
paddle::framework::Tensor tmp = in_grad.Slice(begin, end);
for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) {
for (int64_t m = 0; m != second_dim; ++m) {
EXPECT_EQ(tmp.data<T>()[m + j * second_dim],
out_grad.data<T>()[m + i * second_dim]);
}
}
}
} else {
for (int64_t i = 0; i < cpu_in_grad.lod()[0].size() - 1; ++i) {
int64_t begin = cpu_in_grad.lod()[0][i];
int64_t end = cpu_in_grad.lod()[0][i + 1];
paddle::framework::Tensor tmp = cpu_in_grad.Slice(begin, end);
for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) {
for (int64_t m = 0; m != second_dim; ++m) {
EXPECT_EQ(tmp.data<T>()[m + j * second_dim],
cpu_out_grad.data<T>()[m + i * second_dim]);
}
}
}
}
delete place;
delete context;
}
TEST(SequencePoolingGrad, CPU_SUM) {
paddle::framework::LoD lod1;
lod1.push_back(std::vector<size_t>{0, 10});
TestSequencePoolingSum<paddle::platform::CPUDeviceContext,
paddle::platform::CPUPlace, float>(lod1);
paddle::framework::LoD lod2;
lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
TestSequencePoolingSum<paddle::platform::CPUDeviceContext,
paddle::platform::CPUPlace, float>(lod2);
}
#ifdef PADDLE_WITH_CUDA
TEST(SequencePoolingGrad, CUDA_SUM) {
paddle::framework::LoD lod1;
lod1.push_back(std::vector<size_t>{0, 10});
TestSequencePoolingSum<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace, float>(lod1);
paddle::framework::LoD lod2;
lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
TestSequencePoolingSum<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace, float>(lod2);
}
#endif
...@@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -34,7 +34,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor) The input of mean op"); AddInput("X", "(Tensor) The input of mean op");
AddOutput("Out", "(Tensor) The output of mean op").Reuse("X"); AddOutput("Out", "(Tensor) The output of mean op");
AddComment(R"DOC( AddComment(R"DOC(
Mean Operator calculates the mean of all elements in X. Mean Operator calculates the mean of all elements in X.
......
...@@ -151,8 +151,7 @@ void Pool2dOpMaker::Make() { ...@@ -151,8 +151,7 @@ void Pool2dOpMaker::Make() {
"The format of output tensor is also NCHW, " "The format of output tensor is also NCHW, "
"where N is batch size, C is the number of channels, " "where N is batch size, C is the number of channels, "
"H is the height of the feature, " "H is the height of the feature, "
"and W is the width of the feature.") "and W is the width of the feature.");
.Reuse("X");
AddAttr<std::string>("pooling_type", AddAttr<std::string>("pooling_type",
"(string), pooling type, can be \"max\" for max-pooling " "(string), pooling type, can be \"max\" for max-pooling "
...@@ -252,8 +251,7 @@ void Pool3dOpMaker::Make() { ...@@ -252,8 +251,7 @@ void Pool3dOpMaker::Make() {
"The format of output tensor is also NCDHW, " "The format of output tensor is also NCDHW, "
"where N is batch size, C is " "where N is batch size, C is "
"the number of channels, and D, H and W is the depth, height and " "the number of channels, and D, H and W is the depth, height and "
"width of the feature, respectively.") "width of the feature, respectively.");
.Reuse("X");
AddAttr<std::string>("pooling_type", AddAttr<std::string>("pooling_type",
"(string) Pooling type, can be \"max\" for max-pooling " "(string) Pooling type, can be \"max\" for max-pooling "
......
...@@ -237,7 +237,7 @@ TEST(BlockingQueue, speed_test_mode) { ...@@ -237,7 +237,7 @@ TEST(BlockingQueue, speed_test_mode) {
} }
for (size_t i = 0; i < queue_size; ++i) { for (size_t i = 0; i < queue_size; ++i) {
q2.Receive(&b); q2.Receive(&b);
EXPECT_EQ(b, 0); EXPECT_EQ(b, 0UL);
} }
EXPECT_EQ(q2.Size(), queue_size); EXPECT_EQ(q2.Size(), queue_size);
} }
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/concat.h" #include "paddle/fluid/operators/math/concat_and_split.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -106,7 +106,7 @@ class SeqConcatGradKernel : public framework::OpKernel<T> { ...@@ -106,7 +106,7 @@ class SeqConcatGradKernel : public framework::OpKernel<T> {
} }
} }
math::ConcatGradFunctor<DeviceContext, T> functor; math::SplitFunctor<DeviceContext, T> functor;
std::vector<const framework::Tensor *> sliced_x_ptr; std::vector<const framework::Tensor *> sliced_x_ptr;
std::vector<framework::Tensor *> sliced_dx_ptr; std::vector<framework::Tensor *> sliced_dx_ptr;
for (auto &x : sliced_x) { for (auto &x : sliced_x) {
......
...@@ -77,8 +77,7 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -77,8 +77,7 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Grad", "(Tensor or SelectedRows) Input gradient"); AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
AddOutput("ParamOut", AddOutput("ParamOut",
"(Tensor or SelectedRows, same with Param) " "(Tensor or SelectedRows, same with Param) "
"Output parameter, should share the same memory with Param") "Output parameter, should share the same memory with Param");
.Reuse("Param");
AddComment(R"DOC( AddComment(R"DOC(
SGD operator SGD operator
......
...@@ -80,8 +80,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -80,8 +80,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", AddInput("X",
"The input tensor of softmax, " "The input tensor of softmax, "
"whose last dimension is the input_feature_dimensions."); "whose last dimension is the input_feature_dimensions.");
AddOutput("Out", "The normalized values with the same shape as X.") AddOutput("Out", "The normalized values with the same shape as X.");
.Reuse("X");
AddAttr<bool>( AddAttr<bool>(
"use_cudnn", "use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn") "(bool, default false) Only used in cudnn kernel, need install cudnn")
......
...@@ -111,11 +111,10 @@ Example: ...@@ -111,11 +111,10 @@ Example:
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
USE_CPU_ONLY_OP(concat);
REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker, ops::SplitGradMaker); REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker, ops::SplitGradMaker);
REGISTER_OP_CPU_KERNEL(split, REGISTER_OP_CPU_KERNEL(
ops::SplitOpKernel<paddle::platform::CPUPlace, double>, split, ops::SplitOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SplitOpKernel<paddle::platform::CPUPlace, float>, ops::SplitOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SplitOpKernel<paddle::platform::CPUPlace, int64_t>, ops::SplitOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SplitOpKernel<paddle::platform::CPUPlace, int>); ops::SplitOpKernel<paddle::platform::CPUDeviceContext, int>);
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <chrono> // NOLINT #include <chrono> // NOLINT
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/strided_memcpy.h"
namespace paddle { namespace paddle {
...@@ -28,18 +29,22 @@ class SplitOpKernel : public framework::OpKernel<T> { ...@@ -28,18 +29,22 @@ class SplitOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out"); auto outs = ctx.MultiOutput<framework::Tensor>("Out");
auto in_stride = framework::stride_numel(in->dims()); int axis = ctx.Attr<int>("axis");
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
size_t input_offset = 0; std::vector<const framework::Tensor*> shape_refer;
for (auto& out : outs) { for (size_t j = 0; j < outs.size(); ++j) {
out->mutable_data<T>(ctx.GetPlace()); outs[j]->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride_numel(out->dims()); shape_refer.emplace_back(outs[j]);
StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(), }
out_stride, in->data<T>() + input_offset,
in_stride, out_stride[axis]); auto& dev_ctx = ctx.template device_context<DeviceContext>();
input_offset += out_stride[axis]; // Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && outs.size() < 10) {
StridedMemcpyWithAxis0<T>(dev_ctx, *in, shape_refer, &outs);
} else {
math::SplitFunctor<DeviceContext, T> functor;
functor(dev_ctx, *in, shape_refer, axis, &outs);
} }
} }
}; };
......
...@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/detail/strided_memcpy.h" #include "paddle/fluid/operators/detail/strided_memcpy.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -98,5 +99,26 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx, ...@@ -98,5 +99,26 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
} }
} }
template <typename T>
inline void StridedMemcpyWithAxis0(
const platform::DeviceContext& dev_ctx, const framework::Tensor& input,
const std::vector<const framework::Tensor*>& shape_refer,
std::vector<framework::Tensor*>* outputs) {
const framework::DDim in_stride = stride_numel(input.dims());
const int axis = 0;
size_t input_offset = 0;
for (size_t i = 0; i < outputs->size(); ++i) {
auto out_stride = stride_numel(shape_refer[i]->dims());
auto out = outputs->at(i);
if (out != nullptr) {
StridedNumelCopyWithAxis<T>(dev_ctx, axis, out->data<T>(), out_stride,
input.data<T>() + input_offset, in_stride,
out_stride[axis]);
}
input_offset += out_stride[axis];
}
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -132,7 +132,7 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -132,7 +132,7 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override { void Make() override {
AddInput("X", "(vector<Tensor>) The input tensors of sum operator.") AddInput("X", "(vector<Tensor>) The input tensors of sum operator.")
.AsDuplicable(); .AsDuplicable();
AddOutput("Out", "(Tensor) The output tensor of sum operator.").Reuse("X"); AddOutput("Out", "(Tensor) The output tensor of sum operator.");
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false);
......
...@@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor) The input of Topk op"); AddInput("X", "(Tensor) The input of Topk op");
AddOutput("Out", "(Tensor) The output tensor of Topk op").Reuse("X"); AddOutput("Out", "(Tensor) The output tensor of Topk op");
AddOutput("Indices", "(Tensor) The indices of Topk elements of input"); AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
AddComment(R"DOC( AddComment(R"DOC(
Top K operator Top K operator
......
...@@ -262,31 +262,31 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, ...@@ -262,31 +262,31 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
const T* src, int lds, int dim, int k, const T* src, int lds, int dim, int k,
int grid_dim, int num) { int grid_dim, int num) {
__shared__ Pair<T> sh_topk[BlockSize]; __shared__ Pair<T> sh_topk[BlockSize];
__shared__ int maxid[BlockSize / 2];
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int warp = threadIdx.x / 32; const int warp = threadIdx.x / 32;
const int bid = blockIdx.x; const int bid = blockIdx.x;
for (int i = bid; i < num; i += grid_dim) { for (int i = bid; i < num; i += grid_dim) {
output += i * output_stride; int top_num = k;
indices += i * k; __shared__ int maxid[BlockSize / 2];
T* out = output + i * output_stride;
int64_t* inds = indices + i * k;
Pair<T> topk[MaxLength]; Pair<T> topk[MaxLength];
int beam = MaxLength; int beam = MaxLength;
Pair<T> max; Pair<T> max;
bool is_empty = false; bool is_empty = false;
bool firststep = true; bool firststep = true;
for (int k = 0; k < MaxLength; k++) { for (int j = 0; j < MaxLength; j++) {
topk[k].set(-INFINITY, -1); topk[j].set(-INFINITY, -1);
} }
while (k) { while (top_num) {
ThreadGetTopK<T, MaxLength, BlockSize>( ThreadGetTopK<T, MaxLength, BlockSize>(
topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid); topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);
sh_topk[tid] = topk[0]; sh_topk[tid] = topk[0];
BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output, BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
&indices, &beam, &k, tid, warp); &beam, &top_num, tid, warp);
} }
} }
} }
...@@ -327,13 +327,15 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { ...@@ -327,13 +327,15 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
size_t k = static_cast<int>(ctx.Attr<int>("k")); size_t k = static_cast<int>(ctx.Attr<int>("k"));
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
// FIXME(typhoonzero): data is always converted to type T? // FIXME(typhoonzero): data is always converted to type T?
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
size_t input_height = input->dims()[0]; framework::DDim inputdims = input->dims();
size_t input_width = input->dims()[1]; const size_t input_height = framework::product(
framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
const size_t input_width = inputdims[inputdims.size() - 1];
if (k > input_width) k = input_width; if (k > input_width) k = input_width;
// NOTE: pass lds and dim same to input width. // NOTE: pass lds and dim same to input width.
...@@ -342,14 +344,12 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { ...@@ -342,14 +344,12 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
const int kMaxHeight = 2048; const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
auto& dev_ctx = ctx.cuda_device_context(); auto& dev_ctx = ctx.cuda_device_context();
switch (GetDesiredBlockDim(input_width)) { switch (GetDesiredBlockDim(input_width)) {
FIXED_BLOCK_DIM( FIXED_BLOCK_DIM(
KeMatrixTopK<T, 5, KeMatrixTopK<T, 5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>( kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
output_data, output->dims()[1], indices_data, input_data, output_data, k, indices_data, input_data, input_width,
input_width, input_width, static_cast<int>(k), gridx, input_width, static_cast<int>(k), gridx, input_height));
input_height));
default: default:
PADDLE_THROW("Error"); PADDLE_THROW("Error");
} }
......
...@@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel<T> { ...@@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
// Get the top k elements of each row of input tensor // Get the top k elements of each row of input tensor
// FIXME: only deal with matrix(2d tensor).
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out"); auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices"); auto* indices = ctx.Output<Tensor>("Indices");
...@@ -44,8 +43,6 @@ class TopkKernel : public framework::OpKernel<T> { ...@@ -44,8 +43,6 @@ class TopkKernel : public framework::OpKernel<T> {
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
auto eg_input = EigenMatrix<T>::From(*input);
// reshape input to a flattern matrix(like flat_inner_dims) // reshape input to a flattern matrix(like flat_inner_dims)
framework::DDim inputdims = input->dims(); framework::DDim inputdims = input->dims();
const size_t row = framework::product( const size_t row = framework::product(
...@@ -53,7 +50,7 @@ class TopkKernel : public framework::OpKernel<T> { ...@@ -53,7 +50,7 @@ class TopkKernel : public framework::OpKernel<T> {
const size_t col = inputdims[inputdims.size() - 1]; const size_t col = inputdims[inputdims.size() - 1];
Eigen::DSizes<int, 2> flat2dims(row, col); Eigen::DSizes<int, 2> flat2dims(row, col);
// NOTE: eigen shape doesn't affect paddle tensor. // NOTE: eigen shape doesn't affect paddle tensor.
eg_input.reshape(flat2dims); auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for #pragma omp parallel for
......
...@@ -324,10 +324,19 @@ class LayerHelper(object): ...@@ -324,10 +324,19 @@ class LayerHelper(object):
raise ValueError("no Parameter name %s found" % name) raise ValueError("no Parameter name %s found" % name)
return param return param
def create_tmp_variable(self, dtype, stop_gradient=False): def create_variable_for_type_inference(self, dtype, stop_gradient=False):
"""Create a temporary variable that should be type inferred layer.
Note:
The default type will be set to LOD_TENSOR. However, when
the var is used as operator output, its type will be updated
based on operator's `VarTypeInference` implementation in
infer_var_type.
"""
return self.main_program.current_block().create_var( return self.main_program.current_block().create_var(
name=unique_name.generate(".".join([self.name, 'tmp'])), name=unique_name.generate(".".join([self.name, 'tmp'])),
dtype=dtype, dtype=dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False, persistable=False,
stop_gradient=stop_gradient) stop_gradient=stop_gradient)
...@@ -388,7 +397,7 @@ class LayerHelper(object): ...@@ -388,7 +397,7 @@ class LayerHelper(object):
b = self.create_parameter( b = self.create_parameter(
attr=bias_attr, shape=size, dtype=input_var.dtype, is_bias=True) attr=bias_attr, shape=size, dtype=input_var.dtype, is_bias=True)
tmp = self.create_tmp_variable(dtype=input_var.dtype) tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
self.append_op( self.append_op(
type='elementwise_add', type='elementwise_add',
inputs={'X': [input_var], inputs={'X': [input_var],
...@@ -414,7 +423,7 @@ class LayerHelper(object): ...@@ -414,7 +423,7 @@ class LayerHelper(object):
tmp = input_var tmp = input_var
# NOTE(dzhwinter): some activation support inplace compution. # NOTE(dzhwinter): some activation support inplace compution.
if not core.IsInplace(act_type): if not core.IsInplace(act_type):
tmp = self.create_tmp_variable(dtype=input_var.dtype) tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
self.append_op( self.append_op(
type=act_type, type=act_type,
inputs={"X": [input_var]}, inputs={"X": [input_var]},
......
...@@ -80,8 +80,8 @@ def split_lod_tensor(input, mask, level=0): ...@@ -80,8 +80,8 @@ def split_lod_tensor(input, mask, level=0):
""" """
helper = LayerHelper('split_lod_tensor', **locals()) helper = LayerHelper('split_lod_tensor', **locals())
out_true = helper.create_tmp_variable(dtype=input.dtype) out_true = helper.create_variable_for_type_inference(dtype=input.dtype)
out_false = helper.create_tmp_variable(dtype=input.dtype) out_false = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op( helper.append_op(
type='split_lod_tensor', type='split_lod_tensor',
inputs={ inputs={
...@@ -131,7 +131,7 @@ def merge_lod_tensor(in_true, in_false, x, mask, level=0): ...@@ -131,7 +131,7 @@ def merge_lod_tensor(in_true, in_false, x, mask, level=0):
in_true=out_true, in_false=out_false, mask=y, x=x, level=level) in_true=out_true, in_false=out_false, mask=y, x=x, level=level)
""" """
helper = LayerHelper('merge_lod_tensor', **locals()) helper = LayerHelper('merge_lod_tensor', **locals())
out = helper.create_tmp_variable(dtype=in_true.dtype) out = helper.create_variable_for_type_inference(dtype=in_true.dtype)
helper.append_op( helper.append_op(
type='merge_lod_tensor', type='merge_lod_tensor',
inputs={'X': x, inputs={'X': x,
...@@ -524,7 +524,7 @@ class StaticRNN(object): ...@@ -524,7 +524,7 @@ class StaticRNN(object):
if not isinstance(o, Variable): if not isinstance(o, Variable):
raise TypeError("step output takes a Variable") raise TypeError("step output takes a Variable")
tmp_o = self.helper.create_tmp_variable(dtype=o.dtype) tmp_o = self.helper.create_variable_for_type_inference(dtype=o.dtype)
self.helper.append_op( self.helper.append_op(
type='rnn_memory_helper', type='rnn_memory_helper',
inputs={'X': [o]}, inputs={'X': [o]},
...@@ -606,7 +606,8 @@ class StaticRNN(object): ...@@ -606,7 +606,8 @@ class StaticRNN(object):
pre_memories.append(mem.pre_mem.name) pre_memories.append(mem.pre_mem.name)
mem_var = rnn_block.var(mem.mem.name) mem_var = rnn_block.var(mem.mem.name)
assert isinstance(mem_var, Variable) assert isinstance(mem_var, Variable)
new_mem = self.helper.create_tmp_variable(dtype=mem_var.dtype) new_mem = self.helper.create_variable_for_type_inference(
dtype=mem_var.dtype)
rnn_block.append_op( rnn_block.append_op(
type='rnn_memory_helper', type='rnn_memory_helper',
...@@ -813,7 +814,7 @@ def max_sequence_len(rank_table): ...@@ -813,7 +814,7 @@ def max_sequence_len(rank_table):
${out_comment}. ${out_comment}.
""" """
helper = LayerHelper("max_seqence_len", **locals()) helper = LayerHelper("max_seqence_len", **locals())
res = helper.create_tmp_variable(dtype="int64") res = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op( helper.append_op(
type="max_sequence_len", type="max_sequence_len",
inputs={"RankTable": rank_table}, inputs={"RankTable": rank_table},
...@@ -884,7 +885,7 @@ def array_to_lod_tensor(x, table): ...@@ -884,7 +885,7 @@ def array_to_lod_tensor(x, table):
lod_tensor = fluid.layers.array_to_lod_tensor(array, table) lod_tensor = fluid.layers.array_to_lod_tensor(array, table)
""" """
helper = LayerHelper("array_to_lod_tensor", **locals()) helper = LayerHelper("array_to_lod_tensor", **locals())
tmp = helper.create_tmp_variable(dtype=x.dtype) tmp = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type="array_to_lod_tensor", type="array_to_lod_tensor",
inputs={'X': x, inputs={'X': x,
...@@ -915,7 +916,7 @@ def increment(x, value=1.0, in_place=True): ...@@ -915,7 +916,7 @@ def increment(x, value=1.0, in_place=True):
""" """
helper = LayerHelper("increment", **locals()) helper = LayerHelper("increment", **locals())
if not in_place: if not in_place:
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
out = x out = x
helper.append_op( helper.append_op(
...@@ -1012,7 +1013,7 @@ def less_than(x, y, force_cpu=None, cond=None, **ignored): ...@@ -1012,7 +1013,7 @@ def less_than(x, y, force_cpu=None, cond=None, **ignored):
""" """
helper = LayerHelper("less_than", **locals()) helper = LayerHelper("less_than", **locals())
if cond is None: if cond is None:
cond = helper.create_tmp_variable(dtype='bool') cond = helper.create_variable_for_type_inference(dtype='bool')
cond.stop_gradient = True cond.stop_gradient = True
attrs = dict() attrs = dict()
...@@ -1051,7 +1052,7 @@ def equal(x, y, cond=None, **ignored): ...@@ -1051,7 +1052,7 @@ def equal(x, y, cond=None, **ignored):
""" """
helper = LayerHelper("equal", **locals()) helper = LayerHelper("equal", **locals())
if cond is None: if cond is None:
cond = helper.create_tmp_variable(dtype='bool') cond = helper.create_variable_for_type_inference(dtype='bool')
cond.stop_gradient = True cond.stop_gradient = True
helper.append_op( helper.append_op(
...@@ -1098,7 +1099,7 @@ def array_read(array, i): ...@@ -1098,7 +1099,7 @@ def array_read(array, i):
array, array,
Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY: Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
raise TypeError("array should be tensor array vairable") raise TypeError("array should be tensor array vairable")
out = helper.create_tmp_variable(dtype=array.dtype) out = helper.create_variable_for_type_inference(dtype=array.dtype)
helper.append_op( helper.append_op(
type='read_from_array', type='read_from_array',
inputs={'X': [array], inputs={'X': [array],
...@@ -1133,7 +1134,7 @@ def shrink_memory(x, i, table): ...@@ -1133,7 +1134,7 @@ def shrink_memory(x, i, table):
usage. usage.
""" """
helper = LayerHelper('shrink_memory', **locals()) helper = LayerHelper('shrink_memory', **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='shrink_rnn_memory', type='shrink_rnn_memory',
inputs={'X': [x], inputs={'X': [x],
...@@ -1170,7 +1171,7 @@ def array_length(array): ...@@ -1170,7 +1171,7 @@ def array_length(array):
""" """
helper = LayerHelper('array_length', **locals()) helper = LayerHelper('array_length', **locals())
tmp = helper.create_tmp_variable(dtype='int64') tmp = helper.create_variable_for_type_inference(dtype='int64')
tmp.stop_gradient = True tmp.stop_gradient = True
helper.append_op( helper.append_op(
type='lod_array_length', inputs={'X': [array]}, outputs={'Out': [tmp]}) type='lod_array_length', inputs={'X': [array]}, outputs={'Out': [tmp]})
...@@ -1590,7 +1591,7 @@ class DynamicRNN(object): ...@@ -1590,7 +1591,7 @@ class DynamicRNN(object):
self.mem_dict = dict() self.mem_dict = dict()
self.output_array = [] self.output_array = []
self.outputs = [] self.outputs = []
self.cond = self.helper.create_tmp_variable(dtype='bool') self.cond = self.helper.create_variable_for_type_inference(dtype='bool')
self.cond.stop_gradient = False self.cond.stop_gradient = False
self.while_op = While(self.cond) self.while_op = While(self.cond)
self.input_array = [] self.input_array = []
...@@ -1924,7 +1925,7 @@ def reorder_lod_tensor_by_rank(x, rank_table): ...@@ -1924,7 +1925,7 @@ def reorder_lod_tensor_by_rank(x, rank_table):
helper.is_instance('x', Variable) helper.is_instance('x', Variable)
helper.is_instance('rank_table', Variable) helper.is_instance('rank_table', Variable)
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='reorder_lod_tensor_by_rank', type='reorder_lod_tensor_by_rank',
inputs={'X': [x], inputs={'X': [x],
...@@ -1958,7 +1959,7 @@ def is_empty(x, cond=None, **ignored): ...@@ -1958,7 +1959,7 @@ def is_empty(x, cond=None, **ignored):
""" """
helper = LayerHelper("is_empty", **locals()) helper = LayerHelper("is_empty", **locals())
if cond is None: if cond is None:
cond = helper.create_tmp_variable(dtype='bool') cond = helper.create_variable_for_type_inference(dtype='bool')
cond.stop_gradient = True cond.stop_gradient = True
elif not isinstance(cond, Variable): elif not isinstance(cond, Variable):
raise TypeError("cond takes a variable") raise TypeError("cond takes a variable")
......
...@@ -116,8 +116,8 @@ def rpn_target_assign(bbox_pred, ...@@ -116,8 +116,8 @@ def rpn_target_assign(bbox_pred,
Returns: Returns:
tuple: tuple:
A tuple(predicted_scores, predicted_location, target_label, A tuple(predicted_scores, predicted_location, target_label,
target_bbox) is returned. The predicted_scores and target_bbox, bbox_inside_weight) is returned. The predicted_scores
predicted_location is the predicted result of the RPN. and predicted_location is the predicted result of the RPN.
The target_label and target_bbox is the ground truth, The target_label and target_bbox is the ground truth,
respectively. The predicted_location is a 2D Tensor with shape respectively. The predicted_location is a 2D Tensor with shape
[F, 4], and the shape of target_bbox is same as the shape of [F, 4], and the shape of target_bbox is same as the shape of
...@@ -126,6 +126,8 @@ def rpn_target_assign(bbox_pred, ...@@ -126,6 +126,8 @@ def rpn_target_assign(bbox_pred,
[F + B, 1], and the shape of target_label is same as the shape [F + B, 1], and the shape of target_label is same as the shape
of the predicted_scores, B is the number of the background of the predicted_scores, B is the number of the background
anchors, the F and B is depends on the input of this operator. anchors, the F and B is depends on the input of this operator.
Bbox_inside_weight represents whether the predicted loc is fake_fg
or not and the shape is [F, 4].
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -138,7 +140,7 @@ def rpn_target_assign(bbox_pred, ...@@ -138,7 +140,7 @@ def rpn_target_assign(bbox_pred,
append_batch_size=False, dtype='float32') append_batch_size=False, dtype='float32')
gt_boxes = layers.data(name='gt_boxes', shape=[10, 4], gt_boxes = layers.data(name='gt_boxes', shape=[10, 4],
append_batch_size=False, dtype='float32') append_batch_size=False, dtype='float32')
loc_pred, score_pred, loc_target, score_target = loc_pred, score_pred, loc_target, score_target, bbox_inside_weight =
fluid.layers.rpn_target_assign(bbox_pred=bbox_pred, fluid.layers.rpn_target_assign(bbox_pred=bbox_pred,
cls_logits=cls_logits, cls_logits=cls_logits,
anchor_box=anchor_box, anchor_box=anchor_box,
...@@ -147,10 +149,13 @@ def rpn_target_assign(bbox_pred, ...@@ -147,10 +149,13 @@ def rpn_target_assign(bbox_pred,
helper = LayerHelper('rpn_target_assign', **locals()) helper = LayerHelper('rpn_target_assign', **locals())
# Assign target label to anchors # Assign target label to anchors
loc_index = helper.create_tmp_variable(dtype='int32') loc_index = helper.create_variable_for_type_inference(dtype='int32')
score_index = helper.create_tmp_variable(dtype='int32') score_index = helper.create_variable_for_type_inference(dtype='int32')
target_label = helper.create_tmp_variable(dtype='int32') target_label = helper.create_variable_for_type_inference(dtype='int32')
target_bbox = helper.create_tmp_variable(dtype=anchor_box.dtype) target_bbox = helper.create_variable_for_type_inference(
dtype=anchor_box.dtype)
bbox_inside_weight = helper.create_variable_for_type_inference(
dtype=anchor_box.dtype)
helper.append_op( helper.append_op(
type="rpn_target_assign", type="rpn_target_assign",
inputs={ inputs={
...@@ -163,7 +168,8 @@ def rpn_target_assign(bbox_pred, ...@@ -163,7 +168,8 @@ def rpn_target_assign(bbox_pred,
'LocationIndex': loc_index, 'LocationIndex': loc_index,
'ScoreIndex': score_index, 'ScoreIndex': score_index,
'TargetLabel': target_label, 'TargetLabel': target_label,
'TargetBBox': target_bbox 'TargetBBox': target_bbox,
'BBoxInsideWeight': bbox_inside_weight
}, },
attrs={ attrs={
'rpn_batch_size_per_im': rpn_batch_size_per_im, 'rpn_batch_size_per_im': rpn_batch_size_per_im,
...@@ -178,13 +184,14 @@ def rpn_target_assign(bbox_pred, ...@@ -178,13 +184,14 @@ def rpn_target_assign(bbox_pred,
score_index.stop_gradient = True score_index.stop_gradient = True
target_label.stop_gradient = True target_label.stop_gradient = True
target_bbox.stop_gradient = True target_bbox.stop_gradient = True
bbox_inside_weight.stop_gradient = True
cls_logits = nn.reshape(x=cls_logits, shape=(-1, 1)) cls_logits = nn.reshape(x=cls_logits, shape=(-1, 1))
bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4)) bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4))
predicted_cls_logits = nn.gather(cls_logits, score_index) predicted_cls_logits = nn.gather(cls_logits, score_index)
predicted_bbox_pred = nn.gather(bbox_pred, loc_index) predicted_bbox_pred = nn.gather(bbox_pred, loc_index)
return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight
def detection_output(loc, def detection_output(loc,
...@@ -282,7 +289,8 @@ def detection_output(loc, ...@@ -282,7 +289,8 @@ def detection_output(loc,
scores = nn.reshape(x=scores, shape=compile_shape, actual_shape=run_shape) scores = nn.reshape(x=scores, shape=compile_shape, actual_shape=run_shape)
scores = nn.transpose(scores, perm=[0, 2, 1]) scores = nn.transpose(scores, perm=[0, 2, 1])
scores.stop_gradient = True scores.stop_gradient = True
nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype) nmsed_outs = helper.create_variable_for_type_inference(
dtype=decoded_box.dtype)
helper.append_op( helper.append_op(
type="multiclass_nms", type="multiclass_nms",
inputs={'Scores': scores, inputs={'Scores': scores,
...@@ -314,7 +322,7 @@ def iou_similarity(x, y, name=None): ...@@ -314,7 +322,7 @@ def iou_similarity(x, y, name=None):
""" """
helper = LayerHelper("iou_similarity", **locals()) helper = LayerHelper("iou_similarity", **locals())
if name is None: if name is None:
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
out = helper.create_variable( out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
...@@ -351,7 +359,8 @@ def box_coder(prior_box, ...@@ -351,7 +359,8 @@ def box_coder(prior_box,
helper = LayerHelper("box_coder", **locals()) helper = LayerHelper("box_coder", **locals())
if name is None: if name is None:
output_box = helper.create_tmp_variable(dtype=prior_box.dtype) output_box = helper.create_variable_for_type_inference(
dtype=prior_box.dtype)
else: else:
output_box = helper.create_variable( output_box = helper.create_variable(
name=name, dtype=prior_box.dtype, persistable=False) name=name, dtype=prior_box.dtype, persistable=False)
...@@ -382,7 +391,7 @@ def polygon_box_transform(input, name=None): ...@@ -382,7 +391,7 @@ def polygon_box_transform(input, name=None):
""" """
helper = LayerHelper("polygon_box_transform", **locals()) helper = LayerHelper("polygon_box_transform", **locals())
if name is None: if name is None:
output = helper.create_tmp_variable(dtype=input.dtype) output = helper.create_variable_for_type_inference(dtype=input.dtype)
else: else:
output = helper.create_variable( output = helper.create_variable(
name=name, dtype=prior_box.input, persistable=False) name=name, dtype=prior_box.input, persistable=False)
...@@ -450,7 +459,7 @@ def detection_map(detect_res, ...@@ -450,7 +459,7 @@ def detection_map(detect_res,
helper = LayerHelper("detection_map", **locals()) helper = LayerHelper("detection_map", **locals())
def __create_var(type): def __create_var(type):
return helper.create_tmp_variable(dtype=type) return helper.create_variable_for_type_inference(dtype=type)
map_out = __create_var('float32') map_out = __create_var('float32')
accum_pos_count_out = out_states[0] if out_states else __create_var('int32') accum_pos_count_out = out_states[0] if out_states else __create_var('int32')
...@@ -557,8 +566,9 @@ def bipartite_match(dist_matrix, ...@@ -557,8 +566,9 @@ def bipartite_match(dist_matrix,
>>> matched_indices, matched_dist = fluid.layers.bipartite_match(iou) >>> matched_indices, matched_dist = fluid.layers.bipartite_match(iou)
""" """
helper = LayerHelper('bipartite_match', **locals()) helper = LayerHelper('bipartite_match', **locals())
match_indices = helper.create_tmp_variable(dtype='int32') match_indices = helper.create_variable_for_type_inference(dtype='int32')
match_distance = helper.create_tmp_variable(dtype=dist_matrix.dtype) match_distance = helper.create_variable_for_type_inference(
dtype=dist_matrix.dtype)
helper.append_op( helper.append_op(
type='bipartite_match', type='bipartite_match',
inputs={'DistMat': dist_matrix}, inputs={'DistMat': dist_matrix},
...@@ -644,8 +654,8 @@ def target_assign(input, ...@@ -644,8 +654,8 @@ def target_assign(input,
gt, matched_indices, mismatch_value=0) gt, matched_indices, mismatch_value=0)
""" """
helper = LayerHelper('target_assign', **locals()) helper = LayerHelper('target_assign', **locals())
out = helper.create_tmp_variable(dtype=input.dtype) out = helper.create_variable_for_type_inference(dtype=input.dtype)
out_weight = helper.create_tmp_variable(dtype='float32') out_weight = helper.create_variable_for_type_inference(dtype='float32')
helper.append_op( helper.append_op(
type='target_assign', type='target_assign',
inputs={ inputs={
...@@ -816,9 +826,10 @@ def ssd_loss(location, ...@@ -816,9 +826,10 @@ def ssd_loss(location,
conf_loss = nn.reshape( conf_loss = nn.reshape(
x=conf_loss, shape=(num, num_prior), actual_shape=actual_shape) x=conf_loss, shape=(num, num_prior), actual_shape=actual_shape)
conf_loss.stop_gradient = True conf_loss.stop_gradient = True
neg_indices = helper.create_tmp_variable(dtype='int32') neg_indices = helper.create_variable_for_type_inference(dtype='int32')
dtype = matched_indices.dtype dtype = matched_indices.dtype
updated_matched_indices = helper.create_tmp_variable(dtype=dtype) updated_matched_indices = helper.create_variable_for_type_inference(
dtype=dtype)
helper.append_op( helper.append_op(
type='mine_hard_examples', type='mine_hard_examples',
inputs={ inputs={
...@@ -998,8 +1009,8 @@ def prior_box(input, ...@@ -998,8 +1009,8 @@ def prior_box(input,
max_sizes = [max_sizes] max_sizes = [max_sizes]
attrs['max_sizes'] = max_sizes attrs['max_sizes'] = max_sizes
box = helper.create_tmp_variable(dtype) box = helper.create_variable_for_type_inference(dtype)
var = helper.create_tmp_variable(dtype) var = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="prior_box", type="prior_box",
inputs={"Input": input, inputs={"Input": input,
...@@ -1337,8 +1348,8 @@ def anchor_generator(input, ...@@ -1337,8 +1348,8 @@ def anchor_generator(input,
'offset': offset 'offset': offset
} }
anchor = helper.create_tmp_variable(dtype) anchor = helper.create_variable_for_type_inference(dtype)
var = helper.create_tmp_variable(dtype) var = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="anchor_generator", type="anchor_generator",
inputs={"Input": input}, inputs={"Input": input},
...@@ -1384,7 +1395,7 @@ def roi_perspective_transform(input, ...@@ -1384,7 +1395,7 @@ def roi_perspective_transform(input,
""" """
helper = LayerHelper('roi_perspective_transform', **locals()) helper = LayerHelper('roi_perspective_transform', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="roi_perspective_transform", type="roi_perspective_transform",
inputs={"X": input, inputs={"X": input,
...@@ -1418,11 +1429,15 @@ def generate_proposal_labels(rpn_rois, ...@@ -1418,11 +1429,15 @@ def generate_proposal_labels(rpn_rois,
helper = LayerHelper('generate_proposal_labels', **locals()) helper = LayerHelper('generate_proposal_labels', **locals())
rois = helper.create_tmp_variable(dtype=rpn_rois.dtype) rois = helper.create_variable_for_type_inference(dtype=rpn_rois.dtype)
labels_int32 = helper.create_tmp_variable(dtype=gt_classes.dtype) labels_int32 = helper.create_variable_for_type_inference(
bbox_targets = helper.create_tmp_variable(dtype=rpn_rois.dtype) dtype=gt_classes.dtype)
bbox_inside_weights = helper.create_tmp_variable(dtype=rpn_rois.dtype) bbox_targets = helper.create_variable_for_type_inference(
bbox_outside_weights = helper.create_tmp_variable(dtype=rpn_rois.dtype) dtype=rpn_rois.dtype)
bbox_inside_weights = helper.create_variable_for_type_inference(
dtype=rpn_rois.dtype)
bbox_outside_weights = helper.create_variable_for_type_inference(
dtype=rpn_rois.dtype)
helper.append_op( helper.append_op(
type="generate_proposal_labels", type="generate_proposal_labels",
...@@ -1504,8 +1519,10 @@ def generate_proposals(scores, ...@@ -1504,8 +1519,10 @@ def generate_proposals(scores,
""" """
helper = LayerHelper('generate_proposals', **locals()) helper = LayerHelper('generate_proposals', **locals())
rpn_rois = helper.create_tmp_variable(dtype=bbox_deltas.dtype) rpn_rois = helper.create_variable_for_type_inference(
rpn_roi_probs = helper.create_tmp_variable(dtype=scores.dtype) dtype=bbox_deltas.dtype)
rpn_roi_probs = helper.create_variable_for_type_inference(
dtype=scores.dtype)
helper.append_op( helper.append_op(
type="generate_proposals", type="generate_proposals",
inputs={ inputs={
......
...@@ -954,7 +954,7 @@ def read_file(reader): ...@@ -954,7 +954,7 @@ def read_file(reader):
""" """
helper = LayerHelper('read_file') helper = LayerHelper('read_file')
out = [ out = [
helper.create_tmp_variable( helper.create_variable_for_type_inference(
stop_gradient=True, dtype='float32') stop_gradient=True, dtype='float32')
for _ in range(len(reader.desc.shapes())) for _ in range(len(reader.desc.shapes()))
] ]
......
...@@ -202,10 +202,12 @@ def generate_layer_fn(op_type): ...@@ -202,10 +202,12 @@ def generate_layer_fn(op_type):
out_var = out[0] if (isinstance(out, list) or out_var = out[0] if (isinstance(out, list) or
isinstance(out, tuple)) else out isinstance(out, tuple)) else out
else: else:
out_var = helper.create_tmp_variable(dtype=dtype) out_var = helper.create_variable_for_type_inference(dtype=dtype)
outputs[o_name] = [out_var] outputs[o_name] = [out_var]
for name in intermediate_output_names: for name in intermediate_output_names:
outputs[name] = [helper.create_tmp_variable(dtype=dtype)] outputs[name] = [
helper.create_variable_for_type_inference(dtype=dtype)
]
helper.append_op( helper.append_op(
type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs) type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
return helper.append_activation(out_var) return helper.append_activation(out_var)
...@@ -229,7 +231,7 @@ def generate_layer_fn_noattr(op_type): ...@@ -229,7 +231,7 @@ def generate_layer_fn_noattr(op_type):
def func(x, name=None): def func(x, name=None):
helper = LayerHelper(op_type, **locals()) helper = LayerHelper(op_type, **locals())
output = helper.create_tmp_variable(dtype=x.dtype) output = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": output}) helper.append_op(type=op_type, inputs={"X": x}, outputs={"Out": output})
return output return output
......
...@@ -58,11 +58,11 @@ def accuracy(input, label, k=1, correct=None, total=None): ...@@ -58,11 +58,11 @@ def accuracy(input, label, k=1, correct=None, total=None):
""" """
helper = LayerHelper("accuracy", **locals()) helper = LayerHelper("accuracy", **locals())
topk_out, topk_indices = nn.topk(input, k=k) topk_out, topk_indices = nn.topk(input, k=k)
acc_out = helper.create_tmp_variable(dtype="float32") acc_out = helper.create_variable_for_type_inference(dtype="float32")
if correct is None: if correct is None:
correct = helper.create_tmp_variable(dtype="int64") correct = helper.create_variable_for_type_inference(dtype="int64")
if total is None: if total is None:
total = helper.create_tmp_variable(dtype="int64") total = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op( helper.append_op(
type="accuracy", type="accuracy",
inputs={ inputs={
...@@ -124,8 +124,8 @@ def auc(input, ...@@ -124,8 +124,8 @@ def auc(input,
auc_out=fluid.layers.auc(input=prediction, label=label) auc_out=fluid.layers.auc(input=prediction, label=label)
""" """
helper = LayerHelper("auc", **locals()) helper = LayerHelper("auc", **locals())
auc_out = helper.create_tmp_variable(dtype="float64") auc_out = helper.create_variable_for_type_inference(dtype="float64")
batch_auc_out = helper.create_tmp_variable(dtype="float64") batch_auc_out = helper.create_variable_for_type_inference(dtype="float64")
# make tp, tn, fp, fn persistable, so that can accumulate all batches. # make tp, tn, fp, fn persistable, so that can accumulate all batches.
# for batch auc # for batch auc
......
...@@ -242,7 +242,7 @@ def fc(input, ...@@ -242,7 +242,7 @@ def fc(input,
w = helper.create_parameter( w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False) attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False)
tmp = helper.create_tmp_variable(dtype) tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="mul", type="mul",
inputs={"X": input_var, inputs={"X": input_var,
...@@ -255,7 +255,7 @@ def fc(input, ...@@ -255,7 +255,7 @@ def fc(input,
if len(mul_results) == 1: if len(mul_results) == 1:
pre_bias = mul_results[0] pre_bias = mul_results[0]
else: else:
pre_bias = helper.create_tmp_variable(dtype) pre_bias = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="sum", type="sum",
inputs={"X": mul_results}, inputs={"X": mul_results},
...@@ -314,7 +314,7 @@ def embedding(input, ...@@ -314,7 +314,7 @@ def embedding(input,
helper = LayerHelper('embedding', **locals()) helper = LayerHelper('embedding', **locals())
w = helper.create_parameter( w = helper.create_parameter(
attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False) attr=helper.param_attr, shape=size, dtype=dtype, is_bias=False)
tmp = helper.create_tmp_variable(dtype) tmp = helper.create_variable_for_type_inference(dtype)
padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
size[0] + padding_idx) size[0] + padding_idx)
helper.append_op( helper.append_op(
...@@ -418,10 +418,10 @@ def dynamic_lstm(input, ...@@ -418,10 +418,10 @@ def dynamic_lstm(input,
bias = helper.create_parameter( bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True) attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True)
hidden = helper.create_tmp_variable(dtype) hidden = helper.create_variable_for_type_inference(dtype)
cell = helper.create_tmp_variable(dtype) cell = helper.create_variable_for_type_inference(dtype)
batch_gate = helper.create_tmp_variable(dtype) batch_gate = helper.create_variable_for_type_inference(dtype)
batch_cell_pre_act = helper.create_tmp_variable(dtype) batch_cell_pre_act = helper.create_variable_for_type_inference(dtype)
inputs = {'Input': input, 'Weight': weight, 'Bias': bias} inputs = {'Input': input, 'Weight': weight, 'Bias': bias}
batch_size = input.shape[0] batch_size = input.shape[0]
if h_0: if h_0:
...@@ -621,12 +621,12 @@ def dynamic_lstmp(input, ...@@ -621,12 +621,12 @@ def dynamic_lstmp(input,
bias = helper.create_parameter( bias = helper.create_parameter(
attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True) attr=helper.bias_attr, shape=bias_size, dtype=dtype, is_bias=True)
projection = helper.create_tmp_variable(dtype) projection = helper.create_variable_for_type_inference(dtype)
cell = helper.create_tmp_variable(dtype) cell = helper.create_variable_for_type_inference(dtype)
ordered_proj0 = helper.create_tmp_variable(dtype) ordered_proj0 = helper.create_variable_for_type_inference(dtype)
batch_hidden = helper.create_tmp_variable(dtype) batch_hidden = helper.create_variable_for_type_inference(dtype)
batch_gate = helper.create_tmp_variable(dtype) batch_gate = helper.create_variable_for_type_inference(dtype)
batch_cell_pre_act = helper.create_tmp_variable(dtype) batch_cell_pre_act = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='lstmp', type='lstmp',
...@@ -751,10 +751,10 @@ def dynamic_gru(input, ...@@ -751,10 +751,10 @@ def dynamic_gru(input,
), 'The shape of h0 should be(batch_size, %d)' % size ), 'The shape of h0 should be(batch_size, %d)' % size
inputs['H0'] = h_0 inputs['H0'] = h_0
hidden = helper.create_tmp_variable(dtype) hidden = helper.create_variable_for_type_inference(dtype)
batch_gate = helper.create_tmp_variable(dtype) batch_gate = helper.create_variable_for_type_inference(dtype)
batch_reset_hidden_prev = helper.create_tmp_variable(dtype) batch_reset_hidden_prev = helper.create_variable_for_type_inference(dtype)
batch_hidden = helper.create_tmp_variable(dtype) batch_hidden = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='gru', type='gru',
...@@ -844,9 +844,9 @@ def gru_unit(input, ...@@ -844,9 +844,9 @@ def gru_unit(input,
weight = helper.create_parameter( weight = helper.create_parameter(
attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype) attr=helper.param_attr, shape=[size, 3 * size], dtype=dtype)
gate = helper.create_tmp_variable(dtype) gate = helper.create_variable_for_type_inference(dtype)
reset_hidden_pre = helper.create_tmp_variable(dtype) reset_hidden_pre = helper.create_variable_for_type_inference(dtype)
updated_hidden = helper.create_tmp_variable(dtype) updated_hidden = helper.create_variable_for_type_inference(dtype)
inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': weight} inputs = {'Input': input, 'HiddenPrev': hidden, 'Weight': weight}
# create bias # create bias
if helper.bias_attr: if helper.bias_attr:
...@@ -896,10 +896,14 @@ def linear_chain_crf(input, label, param_attr=None): ...@@ -896,10 +896,14 @@ def linear_chain_crf(input, label, param_attr=None):
attr=helper.param_attr, attr=helper.param_attr,
shape=[size + 2, size], shape=[size + 2, size],
dtype=helper.input_dtype()) dtype=helper.input_dtype())
alpha = helper.create_tmp_variable(dtype=helper.input_dtype()) alpha = helper.create_variable_for_type_inference(
emission_exps = helper.create_tmp_variable(dtype=helper.input_dtype()) dtype=helper.input_dtype())
transition_exps = helper.create_tmp_variable(dtype=helper.input_dtype()) emission_exps = helper.create_variable_for_type_inference(
log_likelihood = helper.create_tmp_variable(dtype=helper.input_dtype()) dtype=helper.input_dtype())
transition_exps = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
log_likelihood = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
helper.append_op( helper.append_op(
type='linear_chain_crf', type='linear_chain_crf',
inputs={"Emission": [input], inputs={"Emission": [input],
...@@ -938,7 +942,8 @@ def crf_decoding(input, param_attr, label=None): ...@@ -938,7 +942,8 @@ def crf_decoding(input, param_attr, label=None):
""" """
helper = LayerHelper('crf_decoding', **locals()) helper = LayerHelper('crf_decoding', **locals())
transition = helper.get_parameter(param_attr.name) transition = helper.get_parameter(param_attr.name)
viterbi_path = helper.create_tmp_variable(dtype=helper.input_dtype()) viterbi_path = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
helper.append_op( helper.append_op(
type='crf_decoding', type='crf_decoding',
inputs={"Emission": [input], inputs={"Emission": [input],
...@@ -962,9 +967,9 @@ def cos_sim(X, Y): ...@@ -962,9 +967,9 @@ def cos_sim(X, Y):
Variable: the output of cosine(X, Y). Variable: the output of cosine(X, Y).
""" """
helper = LayerHelper('cos_sim', **locals()) helper = LayerHelper('cos_sim', **locals())
out = helper.create_tmp_variable(dtype=X.dtype) out = helper.create_variable_for_type_inference(dtype=X.dtype)
xnorm = helper.create_tmp_variable(dtype=X.dtype) xnorm = helper.create_variable_for_type_inference(dtype=X.dtype)
ynorm = helper.create_tmp_variable(dtype=X.dtype) ynorm = helper.create_variable_for_type_inference(dtype=X.dtype)
helper.append_op( helper.append_op(
type='cos_sim', type='cos_sim',
inputs={'X': [X], inputs={'X': [X],
...@@ -1008,8 +1013,9 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None): ...@@ -1008,8 +1013,9 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
""" """
helper = LayerHelper('dropout', **locals()) helper = LayerHelper('dropout', **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
mask = helper.create_tmp_variable(dtype=x.dtype, stop_gradient=True) mask = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
if (seed is None or seed == 0) and helper.main_program.random_seed != 0: if (seed is None or seed == 0) and helper.main_program.random_seed != 0:
seed = helper.main_program.random_seed seed = helper.main_program.random_seed
...@@ -1094,7 +1100,7 @@ def cross_entropy(input, label, soft_label=False, ignore_index=-100): ...@@ -1094,7 +1100,7 @@ def cross_entropy(input, label, soft_label=False, ignore_index=-100):
cost = fluid.layers.cross_entropy(input=predict, label=label) cost = fluid.layers.cross_entropy(input=predict, label=label)
""" """
helper = LayerHelper('cross_entropy', **locals()) helper = LayerHelper('cross_entropy', **locals())
out = helper.create_tmp_variable(dtype=input.dtype) out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op( helper.append_op(
type='cross_entropy', type='cross_entropy',
inputs={'X': [input], inputs={'X': [input],
...@@ -1141,14 +1147,14 @@ def square_error_cost(input, label): ...@@ -1141,14 +1147,14 @@ def square_error_cost(input, label):
""" """
helper = LayerHelper('square_error_cost', **locals()) helper = LayerHelper('square_error_cost', **locals())
minus_out = helper.create_tmp_variable(dtype=input.dtype) minus_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op( helper.append_op(
type='elementwise_sub', type='elementwise_sub',
inputs={'X': [input], inputs={'X': [input],
'Y': [label]}, 'Y': [label]},
outputs={'Out': [minus_out]}) outputs={'Out': [minus_out]})
square_out = helper.create_tmp_variable(dtype=input.dtype) square_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op( helper.append_op(
type='square', inputs={'X': [minus_out]}, type='square', inputs={'X': [minus_out]},
outputs={'Out': [square_out]}) outputs={'Out': [square_out]})
...@@ -1254,12 +1260,13 @@ def chunk_eval(input, ...@@ -1254,12 +1260,13 @@ def chunk_eval(input,
helper = LayerHelper("chunk_eval", **locals()) helper = LayerHelper("chunk_eval", **locals())
# prepare output # prepare output
precision = helper.create_tmp_variable(dtype="float32") precision = helper.create_variable_for_type_inference(dtype="float32")
recall = helper.create_tmp_variable(dtype="float32") recall = helper.create_variable_for_type_inference(dtype="float32")
f1_score = helper.create_tmp_variable(dtype="float32") f1_score = helper.create_variable_for_type_inference(dtype="float32")
num_infer_chunks = helper.create_tmp_variable(dtype="int64") num_infer_chunks = helper.create_variable_for_type_inference(dtype="int64")
num_label_chunks = helper.create_tmp_variable(dtype="int64") num_label_chunks = helper.create_variable_for_type_inference(dtype="int64")
num_correct_chunks = helper.create_tmp_variable(dtype="int64") num_correct_chunks = helper.create_variable_for_type_inference(
dtype="int64")
helper.append_op( helper.append_op(
type="chunk_eval", type="chunk_eval",
...@@ -1326,7 +1333,7 @@ def sequence_conv(input, ...@@ -1326,7 +1333,7 @@ def sequence_conv(input,
filter_shape = [filter_size * input.shape[1], num_filters] filter_shape = [filter_size * input.shape[1], num_filters]
filter_param = helper.create_parameter( filter_param = helper.create_parameter(
attr=helper.param_attr, shape=filter_shape, dtype=dtype) attr=helper.param_attr, shape=filter_shape, dtype=dtype)
pre_bias = helper.create_tmp_variable(dtype) pre_bias = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='sequence_conv', type='sequence_conv',
...@@ -1382,7 +1389,7 @@ def sequence_softmax(input, use_cudnn=False, name=None): ...@@ -1382,7 +1389,7 @@ def sequence_softmax(input, use_cudnn=False, name=None):
""" """
helper = LayerHelper('sequence_softmax', **locals()) helper = LayerHelper('sequence_softmax', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
softmax_out = helper.create_tmp_variable(dtype) softmax_out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="sequence_softmax", type="sequence_softmax",
inputs={"X": input}, inputs={"X": input},
...@@ -1436,7 +1443,7 @@ def softmax(input, use_cudnn=True, name=None): ...@@ -1436,7 +1443,7 @@ def softmax(input, use_cudnn=True, name=None):
""" """
helper = LayerHelper('softmax', **locals()) helper = LayerHelper('softmax', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
softmax_out = helper.create_tmp_variable(dtype) softmax_out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="softmax", type="softmax",
inputs={"X": input}, inputs={"X": input},
...@@ -1599,7 +1606,7 @@ def conv2d(input, ...@@ -1599,7 +1606,7 @@ def conv2d(input,
dtype=dtype, dtype=dtype,
default_initializer=_get_default_param_initializer()) default_initializer=_get_default_param_initializer())
pre_bias = helper.create_tmp_variable(dtype) pre_bias = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type=l_type, type=l_type,
...@@ -1770,7 +1777,7 @@ def conv3d(input, ...@@ -1770,7 +1777,7 @@ def conv3d(input,
dtype=dtype, dtype=dtype,
default_initializer=_get_default_param_initializer()) default_initializer=_get_default_param_initializer())
pre_bias = helper.create_tmp_variable(dtype) pre_bias = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type=l_type, type=l_type,
...@@ -1849,8 +1856,8 @@ def sequence_pool(input, pool_type): ...@@ -1849,8 +1856,8 @@ def sequence_pool(input, pool_type):
""" """
helper = LayerHelper('sequence_pool', **locals()) helper = LayerHelper('sequence_pool', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
pool_out = helper.create_tmp_variable(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
max_index = helper.create_tmp_variable(dtype) max_index = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="sequence_pool", type="sequence_pool",
...@@ -1886,7 +1893,7 @@ def sequence_concat(input, name=None): ...@@ -1886,7 +1893,7 @@ def sequence_concat(input, name=None):
out = fluid.layers.sequence_concat(input=[seq1, seq2, seq3]) out = fluid.layers.sequence_concat(input=[seq1, seq2, seq3])
""" """
helper = LayerHelper('sequence_concat', **locals()) helper = LayerHelper('sequence_concat', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
helper.append_op( helper.append_op(
type='sequence_concat', inputs={'X': input}, outputs={'Out': [out]}) type='sequence_concat', inputs={'X': input}, outputs={'Out': [out]})
return out return out
...@@ -2013,7 +2020,7 @@ def sequence_slice(input, offset, length, name=None): ...@@ -2013,7 +2020,7 @@ def sequence_slice(input, offset, length, name=None):
""" """
helper = LayerHelper("sequence_slice", **locals()) helper = LayerHelper("sequence_slice", **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
offset.stop_gradient = True offset.stop_gradient = True
length.stop_gradient = True length.stop_gradient = True
...@@ -2099,7 +2106,7 @@ def pool2d(input, ...@@ -2099,7 +2106,7 @@ def pool2d(input,
helper = LayerHelper(l_type, **locals()) helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
pool_out = helper.create_tmp_variable(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type=l_type, type=l_type,
...@@ -2167,7 +2174,7 @@ def pool3d(input, ...@@ -2167,7 +2174,7 @@ def pool3d(input,
l_type = "pool3d" l_type = "pool3d"
helper = LayerHelper(l_type, **locals()) helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
pool_out = helper.create_tmp_variable(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type=l_type, type=l_type,
...@@ -2310,10 +2317,13 @@ def batch_norm(input, ...@@ -2310,10 +2317,13 @@ def batch_norm(input,
mean_out = mean mean_out = mean
# variance and variance out share the same memory # variance and variance out share the same memory
variance_out = variance variance_out = variance
saved_mean = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) saved_mean = helper.create_variable_for_type_inference(
saved_variance = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) dtype=dtype, stop_gradient=True)
saved_variance = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
batch_norm_out = input if in_place else helper.create_tmp_variable(dtype) batch_norm_out = input if in_place else helper.create_variable_for_type_inference(
dtype)
helper.append_op( helper.append_op(
type="batch_norm", type="batch_norm",
...@@ -2430,9 +2440,11 @@ def layer_norm(input, ...@@ -2430,9 +2440,11 @@ def layer_norm(input,
inputs['Bias'] = bias inputs['Bias'] = bias
# create output # create output
mean_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) mean_out = helper.create_variable_for_type_inference(
variance_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) dtype=dtype, stop_gradient=True)
layer_norm_out = helper.create_tmp_variable(dtype) variance_out = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
layer_norm_out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="layer_norm", type="layer_norm",
...@@ -2619,7 +2631,7 @@ def conv2d_transpose(input, ...@@ -2619,7 +2631,7 @@ def conv2d_transpose(input,
img_filter = helper.create_parameter( img_filter = helper.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=helper.param_attr) dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
pre_bias = helper.create_tmp_variable(dtype=input.dtype) pre_bias = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op( helper.append_op(
type=op_type, type=op_type,
inputs={'Input': [input], inputs={'Input': [input],
...@@ -2797,7 +2809,7 @@ def conv3d_transpose(input, ...@@ -2797,7 +2809,7 @@ def conv3d_transpose(input,
img_filter = helper.create_parameter( img_filter = helper.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=helper.param_attr) dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
pre_bias = helper.create_tmp_variable(dtype=input.dtype) pre_bias = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op( helper.append_op(
type=l_type, type=l_type,
inputs={'Input': [input], inputs={'Input': [input],
...@@ -2876,7 +2888,7 @@ def sequence_expand(x, y, ref_level=-1, name=None): ...@@ -2876,7 +2888,7 @@ def sequence_expand(x, y, ref_level=-1, name=None):
""" """
helper = LayerHelper('sequence_expand', input=x, **locals()) helper = LayerHelper('sequence_expand', input=x, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
tmp = helper.create_tmp_variable(dtype) tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='sequence_expand', type='sequence_expand',
inputs={'X': x, inputs={'X': x,
...@@ -2942,7 +2954,7 @@ def sequence_expand_as(x, y, name=None): ...@@ -2942,7 +2954,7 @@ def sequence_expand_as(x, y, name=None):
""" """
helper = LayerHelper('sequence_expand_as', input=x, **locals()) helper = LayerHelper('sequence_expand_as', input=x, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
tmp = helper.create_tmp_variable(dtype) tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='sequence_expand_as', type='sequence_expand_as',
inputs={'X': x, inputs={'X': x,
...@@ -2987,8 +2999,8 @@ def sequence_pad(x, pad_value, maxlen=None, name=None): ...@@ -2987,8 +2999,8 @@ def sequence_pad(x, pad_value, maxlen=None, name=None):
helper = LayerHelper('sequence_pad', input=x, **locals()) helper = LayerHelper('sequence_pad', input=x, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
length = helper.create_tmp_variable(dtype) length = helper.create_variable_for_type_inference(dtype)
pad_value.stop_gradient = True pad_value.stop_gradient = True
length.stop_gradient = True length.stop_gradient = True
...@@ -3053,7 +3065,7 @@ def sequence_unpad(x, length, name=None): ...@@ -3053,7 +3065,7 @@ def sequence_unpad(x, length, name=None):
helper = LayerHelper('sequence_unpad', input=x, **locals()) helper = LayerHelper('sequence_unpad', input=x, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
length.stop_gradient = True length.stop_gradient = True
...@@ -3152,8 +3164,9 @@ def beam_search(pre_ids, ...@@ -3152,8 +3164,9 @@ def beam_search(pre_ids,
score_type = scores.dtype score_type = scores.dtype
id_type = ids.dtype id_type = ids.dtype
selected_scores = helper.create_tmp_variable(dtype=score_type) selected_scores = helper.create_variable_for_type_inference(
selected_ids = helper.create_tmp_variable(dtype=id_type) dtype=score_type)
selected_ids = helper.create_variable_for_type_inference(dtype=id_type)
helper.append_op( helper.append_op(
type='beam_search', type='beam_search',
...@@ -3210,8 +3223,8 @@ def beam_search_decode(ids, scores, beam_size, end_id, name=None): ...@@ -3210,8 +3223,8 @@ def beam_search_decode(ids, scores, beam_size, end_id, name=None):
ids, scores, beam_size=5, end_id=0) ids, scores, beam_size=5, end_id=0)
""" """
helper = LayerHelper('beam_search_decode', **locals()) helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_tmp_variable(dtype=ids.dtype) sentence_ids = helper.create_variable_for_type_inference(dtype=ids.dtype)
sentence_scores = helper.create_tmp_variable(dtype=ids.dtype) sentence_scores = helper.create_variable_for_type_inference(dtype=ids.dtype)
helper.append_op( helper.append_op(
type="beam_search_decode", type="beam_search_decode",
...@@ -3341,8 +3354,8 @@ def lstm_unit(x_t, ...@@ -3341,8 +3354,8 @@ def lstm_unit(x_t,
param_attr=param_attr, param_attr=param_attr,
bias_attr=bias_attr) bias_attr=bias_attr)
dtype = x_t.dtype dtype = x_t.dtype
c = helper.create_tmp_variable(dtype) c = helper.create_variable_for_type_inference(dtype)
h = helper.create_tmp_variable(dtype) h = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='lstm_unit', type='lstm_unit',
...@@ -3396,7 +3409,7 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None): ...@@ -3396,7 +3409,7 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
""" """
helper = LayerHelper('reduce_sum', **locals()) helper = LayerHelper('reduce_sum', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list): if dim is not None and not isinstance(dim, list):
dim = [dim] dim = [dim]
helper.append_op( helper.append_op(
...@@ -3453,7 +3466,7 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None): ...@@ -3453,7 +3466,7 @@ def reduce_mean(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_mean(x, dim=[0, 1]) # [4.0, 5.0] fluid.layers.reduce_mean(x, dim=[0, 1]) # [4.0, 5.0]
""" """
helper = LayerHelper('reduce_mean', **locals()) helper = LayerHelper('reduce_mean', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list): if dim is not None and not isinstance(dim, list):
dim = [dim] dim = [dim]
helper.append_op( helper.append_op(
...@@ -3508,7 +3521,7 @@ def reduce_max(input, dim=None, keep_dim=False, name=None): ...@@ -3508,7 +3521,7 @@ def reduce_max(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_max(x, dim=[0, 1]) # [7.0, 8.0] fluid.layers.reduce_max(x, dim=[0, 1]) # [7.0, 8.0]
""" """
helper = LayerHelper('reduce_max', **locals()) helper = LayerHelper('reduce_max', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list): if dim is not None and not isinstance(dim, list):
dim = [dim] dim = [dim]
helper.append_op( helper.append_op(
...@@ -3563,7 +3576,7 @@ def reduce_min(input, dim=None, keep_dim=False, name=None): ...@@ -3563,7 +3576,7 @@ def reduce_min(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_min(x, dim=[0, 1]) # [1.0, 2.0] fluid.layers.reduce_min(x, dim=[0, 1]) # [1.0, 2.0]
""" """
helper = LayerHelper('reduce_min', **locals()) helper = LayerHelper('reduce_min', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list): if dim is not None and not isinstance(dim, list):
dim = [dim] dim = [dim]
helper.append_op( helper.append_op(
...@@ -3619,7 +3632,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): ...@@ -3619,7 +3632,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None):
fluid.layers.reduce_prod(x, dim=[0, 1]) # [105.0, 384.0] fluid.layers.reduce_prod(x, dim=[0, 1]) # [105.0, 384.0]
""" """
helper = LayerHelper('reduce_prod', **locals()) helper = LayerHelper('reduce_prod', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
if dim is not None and not isinstance(dim, list): if dim is not None and not isinstance(dim, list):
dim = [dim] dim = [dim]
helper.append_op( helper.append_op(
...@@ -3679,7 +3692,7 @@ def split(input, num_or_sections, dim=-1, name=None): ...@@ -3679,7 +3692,7 @@ def split(input, num_or_sections, dim=-1, name=None):
dim], 'len(num_or_sections) must not be more than input.shape[dim].' dim], 'len(num_or_sections) must not be more than input.shape[dim].'
num = len(num_or_sections) num = len(num_or_sections)
outs = [ outs = [
helper.create_tmp_variable(dtype=helper.input_dtype()) helper.create_variable_for_type_inference(dtype=helper.input_dtype())
for i in range(num) for i in range(num)
] ]
helper.append_op( helper.append_op(
...@@ -3736,8 +3749,8 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None): ...@@ -3736,8 +3749,8 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
axis = 0 axis = 0
helper = LayerHelper("l2_normalize", **locals()) helper = LayerHelper("l2_normalize", **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
norm = helper.create_tmp_variable(dtype=x.dtype) norm = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type="norm", type="norm",
inputs={"X": x}, inputs={"X": x},
...@@ -3846,7 +3859,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None): ...@@ -3846,7 +3859,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
__check_input(x, y) __check_input(x, y)
helper = LayerHelper('matmul', **locals()) helper = LayerHelper('matmul', **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='matmul', type='matmul',
inputs={'X': x, inputs={'X': x,
...@@ -3917,8 +3930,8 @@ def topk(input, k, name=None): ...@@ -3917,8 +3930,8 @@ def topk(input, k, name=None):
top5_values, top5_indices = layers.topk(input, k=5) top5_values, top5_indices = layers.topk(input, k=5)
""" """
helper = LayerHelper("top_k", **locals()) helper = LayerHelper("top_k", **locals())
values = helper.create_tmp_variable(dtype=input.dtype) values = helper.create_variable_for_type_inference(dtype=input.dtype)
indices = helper.create_tmp_variable(dtype="int64") indices = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op( helper.append_op(
type="top_k", type="top_k",
inputs={"X": [input]}, inputs={"X": [input]},
...@@ -3976,8 +3989,8 @@ def edit_distance(input, label, normalized=True, ignored_tokens=None): ...@@ -3976,8 +3989,8 @@ def edit_distance(input, label, normalized=True, ignored_tokens=None):
# remove some tokens from input and labels # remove some tokens from input and labels
if ignored_tokens is not None and len(ignored_tokens) > 0: if ignored_tokens is not None and len(ignored_tokens) > 0:
erased_input = helper.create_tmp_variable(dtype="int64") erased_input = helper.create_variable_for_type_inference(dtype="int64")
erased_label = helper.create_tmp_variable(dtype="int64") erased_label = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op( helper.append_op(
type="sequence_erase", type="sequence_erase",
...@@ -3994,8 +4007,8 @@ def edit_distance(input, label, normalized=True, ignored_tokens=None): ...@@ -3994,8 +4007,8 @@ def edit_distance(input, label, normalized=True, ignored_tokens=None):
label = erased_label label = erased_label
# edit distance op # edit distance op
edit_distance_out = helper.create_tmp_variable(dtype="int64") edit_distance_out = helper.create_variable_for_type_inference(dtype="int64")
sequence_num = helper.create_tmp_variable(dtype="int64") sequence_num = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op( helper.append_op(
type="edit_distance", type="edit_distance",
inputs={"Hyps": [input], inputs={"Hyps": [input],
...@@ -4070,7 +4083,7 @@ def ctc_greedy_decoder(input, blank, name=None): ...@@ -4070,7 +4083,7 @@ def ctc_greedy_decoder(input, blank, name=None):
_, topk_indices = topk(input, k=1) _, topk_indices = topk(input, k=1)
# ctc align op # ctc align op
ctc_out = helper.create_tmp_variable(dtype="int64") ctc_out = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op( helper.append_op(
type="ctc_align", type="ctc_align",
inputs={"Input": [topk_indices]}, inputs={"Input": [topk_indices]},
...@@ -4120,8 +4133,8 @@ def warpctc(input, label, blank=0, norm_by_times=False): ...@@ -4120,8 +4133,8 @@ def warpctc(input, label, blank=0, norm_by_times=False):
""" """
helper = LayerHelper('warpctc', **locals()) helper = LayerHelper('warpctc', **locals())
loss_out = helper.create_tmp_variable(dtype=input.dtype) loss_out = helper.create_variable_for_type_inference(dtype=input.dtype)
grad_out = helper.create_tmp_variable(dtype=input.dtype) grad_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op( helper.append_op(
type='warpctc', type='warpctc',
inputs={'Logits': [input], inputs={'Logits': [input],
...@@ -4182,7 +4195,7 @@ def sequence_reshape(input, new_dim): ...@@ -4182,7 +4195,7 @@ def sequence_reshape(input, new_dim):
x_reshaped = fluid.layers.sequence_reshape(input=x, new_dim=10) x_reshaped = fluid.layers.sequence_reshape(input=x, new_dim=10)
""" """
helper = LayerHelper('sequence_reshape', **locals()) helper = LayerHelper('sequence_reshape', **locals())
out = helper.create_tmp_variable(helper.input_dtype()) out = helper.create_variable_for_type_inference(helper.input_dtype())
helper.append_op( helper.append_op(
type='sequence_reshape', type='sequence_reshape',
inputs={'X': [input]}, inputs={'X': [input]},
...@@ -4279,9 +4292,9 @@ def nce(input, ...@@ -4279,9 +4292,9 @@ def nce(input,
is_bias=True, is_bias=True,
dtype=input.dtype) dtype=input.dtype)
inputs['Bias'] = b inputs['Bias'] = b
cost = helper.create_tmp_variable(dtype=input.dtype) cost = helper.create_variable_for_type_inference(dtype=input.dtype)
sample_logits = helper.create_tmp_variable(dtype=input.dtype) sample_logits = helper.create_variable_for_type_inference(dtype=input.dtype)
sample_labels = helper.create_tmp_variable(dtype=label.dtype) sample_labels = helper.create_variable_for_type_inference(dtype=label.dtype)
if num_neg_samples is None: if num_neg_samples is None:
num_neg_samples = 10 num_neg_samples = 10
...@@ -4357,8 +4370,8 @@ def hsigmoid(input, ...@@ -4357,8 +4370,8 @@ def hsigmoid(input,
helper = LayerHelper('hierarchical_sigmoid', **locals()) helper = LayerHelper('hierarchical_sigmoid', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
pre_out = helper.create_tmp_variable(dtype) pre_out = helper.create_variable_for_type_inference(dtype)
dim = input.shape[1] dim = input.shape[1]
if num_classes < 2: if num_classes < 2:
raise ValueError("num_classes must not be less than 2.") raise ValueError("num_classes must not be less than 2.")
...@@ -4418,8 +4431,8 @@ def transpose(x, perm, name=None): ...@@ -4418,8 +4431,8 @@ def transpose(x, perm, name=None):
(idx, perm[idx], len(x.shape))) (idx, perm[idx], len(x.shape)))
helper = LayerHelper('transpose', **locals()) helper = LayerHelper('transpose', **locals())
out = helper.create_tmp_variable(x.dtype) out = helper.create_variable_for_type_inference(x.dtype)
x_shape = helper.create_tmp_variable(x.dtype) x_shape = helper.create_variable_for_type_inference(x.dtype)
helper.append_op( helper.append_op(
type='transpose2', type='transpose2',
inputs={'X': [x]}, inputs={'X': [x]},
...@@ -4561,7 +4574,7 @@ def im2sequence(input, ...@@ -4561,7 +4574,7 @@ def im2sequence(input,
inputs["Y"] = input_image_size inputs["Y"] = input_image_size
attrs["out_stride"] = out_stride attrs["out_stride"] = out_stride
helper = LayerHelper('im2sequence', **locals()) helper = LayerHelper('im2sequence', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
helper.append_op( helper.append_op(
type='im2sequence', inputs=inputs, outputs={'Out': out}, attrs=attrs) type='im2sequence', inputs=inputs, outputs={'Out': out}, attrs=attrs)
return out return out
...@@ -4594,7 +4607,7 @@ def row_conv(input, future_context_size, param_attr=None, act=None): ...@@ -4594,7 +4607,7 @@ def row_conv(input, future_context_size, param_attr=None, act=None):
filter_shape = [future_context_size + 1, input.shape[1]] filter_shape = [future_context_size + 1, input.shape[1]]
filter_param = helper.create_parameter( filter_param = helper.create_parameter(
attr=helper.param_attr, shape=filter_shape, dtype=dtype) attr=helper.param_attr, shape=filter_shape, dtype=dtype)
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='row_conv', type='row_conv',
inputs={'X': [input], inputs={'X': [input],
...@@ -4627,7 +4640,7 @@ def multiplex(inputs, index): ...@@ -4627,7 +4640,7 @@ def multiplex(inputs, index):
raise ValueError("inputs should be a list object and contains at least " raise ValueError("inputs should be a list object and contains at least "
"2 elements.") "2 elements.")
out = helper.create_tmp_variable(inputs[0].dtype) out = helper.create_variable_for_type_inference(inputs[0].dtype)
helper.append_op( helper.append_op(
type='multiplex', type='multiplex',
inputs={'X': inputs, inputs={'X': inputs,
...@@ -4698,8 +4711,8 @@ def softmax_with_cross_entropy(logits, ...@@ -4698,8 +4711,8 @@ def softmax_with_cross_entropy(logits,
logits=fc, label=label) logits=fc, label=label)
""" """
helper = LayerHelper('softmax_with_cross_entropy', **locals()) helper = LayerHelper('softmax_with_cross_entropy', **locals())
softmax = helper.create_tmp_variable(dtype=logits.dtype) softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
loss = helper.create_tmp_variable(dtype=logits.dtype) loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
helper.append_op( helper.append_op(
type='softmax_with_cross_entropy', type='softmax_with_cross_entropy',
inputs={'Logits': logits, inputs={'Logits': logits,
...@@ -4749,8 +4762,8 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): ...@@ -4749,8 +4762,8 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
""" """
helper = LayerHelper('smooth_l1_loss', **locals()) helper = LayerHelper('smooth_l1_loss', **locals())
diff = helper.create_tmp_variable(dtype=x.dtype) diff = helper.create_variable_for_type_inference(dtype=x.dtype)
loss = helper.create_tmp_variable(dtype=x.dtype) loss = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='smooth_l1_loss', type='smooth_l1_loss',
inputs={ inputs={
...@@ -4783,7 +4796,7 @@ def one_hot(input, depth): ...@@ -4783,7 +4796,7 @@ def one_hot(input, depth):
one_hot_label = layers.one_hot(input=label, depth=10) one_hot_label = layers.one_hot(input=label, depth=10)
""" """
helper = LayerHelper("one_hot", **locals()) helper = LayerHelper("one_hot", **locals())
one_hot_out = helper.create_tmp_variable(dtype='float32') one_hot_out = helper.create_variable_for_type_inference(dtype='float32')
helper.append_op( helper.append_op(
type="one_hot", type="one_hot",
inputs={'X': input}, inputs={'X': input},
...@@ -4925,8 +4938,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None): ...@@ -4925,8 +4938,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=True, name=None):
"except one unknown dimension.") "except one unknown dimension.")
helper = LayerHelper("reshape2", **locals()) helper = LayerHelper("reshape2", **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
x_shape = helper.create_tmp_variable(dtype=x.dtype) x_shape = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type="reshape2", type="reshape2",
inputs=inputs, inputs=inputs,
...@@ -4975,8 +4988,8 @@ def squeeze(input, axes, name=None): ...@@ -4975,8 +4988,8 @@ def squeeze(input, axes, name=None):
y = layers.sequeeze(input=x, axes=[1]) y = layers.sequeeze(input=x, axes=[1])
""" """
helper = LayerHelper("squeeze", **locals()) helper = LayerHelper("squeeze", **locals())
out = helper.create_tmp_variable(dtype=input.dtype) out = helper.create_variable_for_type_inference(dtype=input.dtype)
x_shape = helper.create_tmp_variable(dtype=input.dtype) x_shape = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op( helper.append_op(
type="squeeze2", type="squeeze2",
inputs={"X": input}, inputs={"X": input},
...@@ -5012,8 +5025,8 @@ def unsqueeze(input, axes, name=None): ...@@ -5012,8 +5025,8 @@ def unsqueeze(input, axes, name=None):
y = layers.unsequeeze(input=x, axes=[1]) y = layers.unsequeeze(input=x, axes=[1])
""" """
helper = LayerHelper("unsqueeze", **locals()) helper = LayerHelper("unsqueeze", **locals())
out = helper.create_tmp_variable(dtype=input.dtype) out = helper.create_variable_for_type_inference(dtype=input.dtype)
x_shape = helper.create_tmp_variable(dtype=input.dtype) x_shape = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op( helper.append_op(
type="unsqueeze2", type="unsqueeze2",
inputs={"X": input}, inputs={"X": input},
...@@ -5103,7 +5116,7 @@ def lod_reset(x, y=None, target_lod=None): ...@@ -5103,7 +5116,7 @@ def lod_reset(x, y=None, target_lod=None):
out = layers.lod_reset(x=x, y=y) out = layers.lod_reset(x=x, y=y)
""" """
helper = LayerHelper("lod_reset", **locals()) helper = LayerHelper("lod_reset", **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
if y is not None: if y is not None:
helper.append_op( helper.append_op(
type="lod_reset", inputs={'X': x, type="lod_reset", inputs={'X': x,
...@@ -5172,8 +5185,9 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None): ...@@ -5172,8 +5185,9 @@ def lrn(input, n=5, k=1.0, alpha=1e-4, beta=0.75, name=None):
"dims of input must be 4(not %d), and it's order must be NCHW" % "dims of input must be 4(not %d), and it's order must be NCHW" %
(dims)) (dims))
mid_out = helper.create_tmp_variable(dtype=dtype, stop_gradient=True) mid_out = helper.create_variable_for_type_inference(
lrn_out = helper.create_tmp_variable(dtype) dtype=dtype, stop_gradient=True)
lrn_out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="lrn", type="lrn",
inputs={"X": input}, inputs={"X": input},
...@@ -5238,7 +5252,7 @@ def pad(x, paddings, pad_value=0., name=None): ...@@ -5238,7 +5252,7 @@ def pad(x, paddings, pad_value=0., name=None):
""" """
helper = LayerHelper('pad', input=x, **locals()) helper = LayerHelper('pad', input=x, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='pad', type='pad',
inputs={'X': x}, inputs={'X': x},
...@@ -5318,7 +5332,7 @@ def pad_constant_like(x, y, pad_value=0., name=None): ...@@ -5318,7 +5332,7 @@ def pad_constant_like(x, y, pad_value=0., name=None):
""" """
helper = LayerHelper('pad_constant_like', input=x, **locals()) helper = LayerHelper('pad_constant_like', input=x, **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='pad_constant_like', type='pad_constant_like',
inputs={'X': x, inputs={'X': x,
...@@ -5383,7 +5397,7 @@ def label_smooth(label, ...@@ -5383,7 +5397,7 @@ def label_smooth(label,
raise ValueError("The value of epsilon must be between 0 and 1.") raise ValueError("The value of epsilon must be between 0 and 1.")
helper = LayerHelper("label_smooth", **locals()) helper = LayerHelper("label_smooth", **locals())
label.stop_gradient = True label.stop_gradient = True
smooth_label = helper.create_tmp_variable(dtype) smooth_label = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="label_smooth", type="label_smooth",
inputs={"X": label, inputs={"X": label,
...@@ -5415,8 +5429,8 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0): ...@@ -5415,8 +5429,8 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
""" """
helper = LayerHelper('roi_pool', **locals()) helper = LayerHelper('roi_pool', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
pool_out = helper.create_tmp_variable(dtype) pool_out = helper.create_variable_for_type_inference(dtype)
argmaxes = helper.create_tmp_variable(dtype='int32') argmaxes = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op( helper.append_op(
type="roi_pool", type="roi_pool",
inputs={"X": input, inputs={"X": input,
...@@ -5464,7 +5478,7 @@ def roi_align(input, ...@@ -5464,7 +5478,7 @@ def roi_align(input,
""" """
helper = LayerHelper('roi_align', **locals()) helper = LayerHelper('roi_align', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
align_out = helper.create_tmp_variable(dtype) align_out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="roi_align", type="roi_align",
inputs={"X": input, inputs={"X": input,
...@@ -5589,7 +5603,7 @@ def image_resize(input, ...@@ -5589,7 +5603,7 @@ def image_resize(input,
out_h = int(input.shape[2] * scale) out_h = int(input.shape[2] * scale)
out_w = int(input.shape[3] * scale) out_w = int(input.shape[3] * scale)
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type=resample_methods[resample], type=resample_methods[resample],
inputs=inputs, inputs=inputs,
...@@ -5698,7 +5712,7 @@ def gather(input, index): ...@@ -5698,7 +5712,7 @@ def gather(input, index):
""" """
helper = LayerHelper('gather', **locals()) helper = LayerHelper('gather', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="gather", type="gather",
inputs={"X": input, inputs={"X": input,
...@@ -5738,7 +5752,7 @@ def scatter(input, index, updates, name=None): ...@@ -5738,7 +5752,7 @@ def scatter(input, index, updates, name=None):
""" """
helper = LayerHelper('scatter', **locals()) helper = LayerHelper('scatter', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="scatter", type="scatter",
inputs={"X": input, inputs={"X": input,
...@@ -5798,7 +5812,7 @@ def sequence_scatter(input, index, updates, name=None): ...@@ -5798,7 +5812,7 @@ def sequence_scatter(input, index, updates, name=None):
""" """
helper = LayerHelper('sequence_scatter', **locals()) helper = LayerHelper('sequence_scatter', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="sequence_scatter", type="sequence_scatter",
inputs={"X": input, inputs={"X": input,
...@@ -5828,7 +5842,7 @@ def random_crop(x, shape, seed=None): ...@@ -5828,7 +5842,7 @@ def random_crop(x, shape, seed=None):
""" """
helper = LayerHelper("random_crop", **locals()) helper = LayerHelper("random_crop", **locals())
dtype = x.dtype dtype = x.dtype
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
if seed is None: if seed is None:
seed = np.random.randint(-65536, 65536) seed = np.random.randint(-65536, 65536)
op_attrs = {"shape": shape} op_attrs = {"shape": shape}
...@@ -5874,7 +5888,7 @@ def log(x, name=None): ...@@ -5874,7 +5888,7 @@ def log(x, name=None):
""" """
helper = LayerHelper('log', **locals()) helper = LayerHelper('log', **locals())
dtype = helper.input_dtype(input_param_name='x') dtype = helper.input_dtype(input_param_name='x')
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type="log", inputs={"X": x}, outputs={"Out": out}) helper.append_op(type="log", inputs={"X": x}, outputs={"Out": out})
return out return out
...@@ -5905,7 +5919,7 @@ def relu(x, name=None): ...@@ -5905,7 +5919,7 @@ def relu(x, name=None):
""" """
helper = LayerHelper('relu', **locals()) helper = LayerHelper('relu', **locals())
dtype = helper.input_dtype(input_param_name='x') dtype = helper.input_dtype(input_param_name='x')
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type="relu", inputs={"X": x}, outputs={"Out": out}) helper.append_op(type="relu", inputs={"X": x}, outputs={"Out": out})
return out return out
...@@ -5944,9 +5958,9 @@ def mean_iou(input, label, num_classes): ...@@ -5944,9 +5958,9 @@ def mean_iou(input, label, num_classes):
""" """
helper = LayerHelper('mean_iou', **locals()) helper = LayerHelper('mean_iou', **locals())
dtype = helper.input_dtype() dtype = helper.input_dtype()
out_mean_iou = helper.create_tmp_variable(dtype='float32') out_mean_iou = helper.create_variable_for_type_inference(dtype='float32')
out_wrong = helper.create_tmp_variable(dtype='int32') out_wrong = helper.create_variable_for_type_inference(dtype='int32')
out_correct = helper.create_tmp_variable(dtype='int32') out_correct = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op( helper.append_op(
type="mean_iou", type="mean_iou",
inputs={"Predictions": input, inputs={"Predictions": input,
...@@ -6038,7 +6052,7 @@ def crop(x, shape=None, offsets=None, name=None): ...@@ -6038,7 +6052,7 @@ def crop(x, shape=None, offsets=None, name=None):
if offsets is None: if offsets is None:
offsets = [0] * len(x.shape) offsets = [0] * len(x.shape)
out = helper.create_tmp_variable(x.dtype) out = helper.create_variable_for_type_inference(x.dtype)
ipts = {'X': x} ipts = {'X': x}
attrs = {} attrs = {}
if isinstance(shape, Variable): if isinstance(shape, Variable):
...@@ -6118,7 +6132,7 @@ def rank_loss(label, left, right, name=None): ...@@ -6118,7 +6132,7 @@ def rank_loss(label, left, right, name=None):
if not (isinstance(right, Variable)): if not (isinstance(right, Variable)):
raise ValueError("The right should be a Variable") raise ValueError("The right should be a Variable")
out = helper.create_tmp_variable("float32") out = helper.create_variable_for_type_inference("float32")
helper.append_op( helper.append_op(
type='rank_loss', type='rank_loss',
...@@ -6164,8 +6178,8 @@ def margin_rank_loss(label, left, right, margin=0.1, name=None): ...@@ -6164,8 +6178,8 @@ def margin_rank_loss(label, left, right, margin=0.1, name=None):
raise ValueError("The left should be a Variable.") raise ValueError("The left should be a Variable.")
if not isinstance(right, Variable): if not isinstance(right, Variable):
raise ValueError("The right should be a Variable.") raise ValueError("The right should be a Variable.")
out = helper.create_tmp_variable(left.dtype) out = helper.create_variable_for_type_inference(left.dtype)
act = helper.create_tmp_variable(left.dtype) act = helper.create_variable_for_type_inference(left.dtype)
helper.append_op( helper.append_op(
type='margin_rank_loss', type='margin_rank_loss',
inputs={"Label": label, inputs={"Label": label,
...@@ -6250,7 +6264,7 @@ def pad2d(input, ...@@ -6250,7 +6264,7 @@ def pad2d(input,
helper = LayerHelper('pad2d', **locals()) helper = LayerHelper('pad2d', **locals())
dtype = helper.input_dtype(input_param_name='input') dtype = helper.input_dtype(input_param_name='input')
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='pad2d', type='pad2d',
inputs={'X': input}, inputs={'X': input},
...@@ -6279,7 +6293,7 @@ def elu(x, alpha=1.0, name=None): ...@@ -6279,7 +6293,7 @@ def elu(x, alpha=1.0, name=None):
output(${out_type}): ${out_comment} output(${out_type}): ${out_comment}
""" """
helper = LayerHelper('elu', **locals()) helper = LayerHelper('elu', **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='elu', type='elu',
inputs={'X': x}, inputs={'X': x},
...@@ -6302,7 +6316,7 @@ def relu6(x, threshold=6.0, name=None): ...@@ -6302,7 +6316,7 @@ def relu6(x, threshold=6.0, name=None):
output(${out_type}): ${out_comment} output(${out_type}): ${out_comment}
""" """
helper = LayerHelper('relu6', **locals()) helper = LayerHelper('relu6', **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='relu6', type='relu6',
inputs={'X': x}, inputs={'X': x},
...@@ -6325,7 +6339,7 @@ def pow(x, factor=1.0, name=None): ...@@ -6325,7 +6339,7 @@ def pow(x, factor=1.0, name=None):
output(${out_type}): ${out_comment} output(${out_type}): ${out_comment}
""" """
helper = LayerHelper('pow', **locals()) helper = LayerHelper('pow', **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='pow', type='pow',
inputs={'X': x}, inputs={'X': x},
...@@ -6349,7 +6363,7 @@ def stanh(x, scale_a=2.0 / 3.0, scale_b=1.7159, name=None): ...@@ -6349,7 +6363,7 @@ def stanh(x, scale_a=2.0 / 3.0, scale_b=1.7159, name=None):
output(${out_type}): ${out_comment} output(${out_type}): ${out_comment}
""" """
helper = LayerHelper('stanh', **locals()) helper = LayerHelper('stanh', **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='stanh', type='stanh',
inputs={'X': x}, inputs={'X': x},
...@@ -6374,7 +6388,7 @@ def hard_sigmoid(x, slope=0.2, offset=0.5, name=None): ...@@ -6374,7 +6388,7 @@ def hard_sigmoid(x, slope=0.2, offset=0.5, name=None):
output(${out_type}): ${out_comment} output(${out_type}): ${out_comment}
""" """
helper = LayerHelper('hard_sigmoid', **locals()) helper = LayerHelper('hard_sigmoid', **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='hard_sigmoid', type='hard_sigmoid',
inputs={'X': x}, inputs={'X': x},
...@@ -6398,7 +6412,7 @@ def swish(x, beta=1.0, name=None): ...@@ -6398,7 +6412,7 @@ def swish(x, beta=1.0, name=None):
output(${out_type}): ${out_comment} output(${out_type}): ${out_comment}
""" """
helper = LayerHelper('swish', **locals()) helper = LayerHelper('swish', **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='swish', type='swish',
inputs={'X': x}, inputs={'X': x},
...@@ -6450,7 +6464,7 @@ def prelu(x, mode, param_attr=None, name=None): ...@@ -6450,7 +6464,7 @@ def prelu(x, mode, param_attr=None, name=None):
dtype='float32', dtype='float32',
is_bias=False, is_bias=False,
default_initializer=Constant(1.0)) default_initializer=Constant(1.0))
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type="prelu", type="prelu",
inputs={"X": x, inputs={"X": x,
...@@ -6474,7 +6488,7 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None): ...@@ -6474,7 +6488,7 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
output(${out_type}): ${out_comment} output(${out_type}): ${out_comment}
""" """
helper = LayerHelper('brelu', **locals()) helper = LayerHelper('brelu', **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='brelu', type='brelu',
inputs={'X': x}, inputs={'X': x},
...@@ -6497,7 +6511,7 @@ def leaky_relu(x, alpha=0.02, name=None): ...@@ -6497,7 +6511,7 @@ def leaky_relu(x, alpha=0.02, name=None):
output(${out_type}): ${out_comment} output(${out_type}): ${out_comment}
""" """
helper = LayerHelper('leaky_relu', **locals()) helper = LayerHelper('leaky_relu', **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='leaky_relu', type='leaky_relu',
inputs={'X': x}, inputs={'X': x},
...@@ -6519,7 +6533,7 @@ def soft_relu(x, threshold=40.0, name=None): ...@@ -6519,7 +6533,7 @@ def soft_relu(x, threshold=40.0, name=None):
output(${out_type}): ${out_comment} output(${out_type}): ${out_comment}
""" """
helper = LayerHelper('soft_relu', **locals()) helper = LayerHelper('soft_relu', **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='soft_relu', type='soft_relu',
inputs={'X': x}, inputs={'X': x},
...@@ -6586,8 +6600,8 @@ def flatten(x, axis=1, name=None): ...@@ -6586,8 +6600,8 @@ def flatten(x, axis=1, name=None):
if not (isinstance(axis, int)) or axis > len(x.shape) or axis < 0: if not (isinstance(axis, int)) or axis > len(x.shape) or axis < 0:
raise ValueError("The axis should be a int, and in range [0, rank(x)]") raise ValueError("The axis should be a int, and in range [0, rank(x)]")
out = helper.create_tmp_variable(x.dtype) out = helper.create_variable_for_type_inference(x.dtype)
x_shape = helper.create_tmp_variable(x.dtype) x_shape = helper.create_variable_for_type_inference(x.dtype)
helper.append_op( helper.append_op(
type='flatten2', type='flatten2',
inputs={"X": x}, inputs={"X": x},
...@@ -6633,7 +6647,8 @@ def sequence_enumerate(input, win_size, pad_value=0, name=None): ...@@ -6633,7 +6647,8 @@ def sequence_enumerate(input, win_size, pad_value=0, name=None):
out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0) out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0)
""" """
helper = LayerHelper('sequence_enumerate', **locals()) helper = LayerHelper('sequence_enumerate', **locals())
out = helper.create_tmp_variable(helper.input_dtype(), stop_gradient=True) out = helper.create_variable_for_type_inference(
helper.input_dtype(), stop_gradient=True)
helper.append_op( helper.append_op(
type='sequence_enumerate', type='sequence_enumerate',
inputs={'X': input}, inputs={'X': input},
...@@ -6673,9 +6688,9 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None): ...@@ -6673,9 +6688,9 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None):
helper = LayerHelper('sequence_mask', **locals()) helper = LayerHelper('sequence_mask', **locals())
if name is None: if name is None:
out = helper.create_tmp_variable(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
else: else:
out = helper.create_tmp_variable(dtype=dtype, name=name) out = helper.create_variable_for_type_inference(dtype=dtype, name=name)
helper.append_op( helper.append_op(
type='sequence_mask', type='sequence_mask',
...@@ -6718,7 +6733,7 @@ def stack(x, axis=0): ...@@ -6718,7 +6733,7 @@ def stack(x, axis=0):
if not isinstance(x, list) and not isinstance(x, tuple): if not isinstance(x, list) and not isinstance(x, tuple):
x = [x] x = [x]
out = helper.create_tmp_variable(x[0].dtype) out = helper.create_variable_for_type_inference(x[0].dtype)
helper.append_op( helper.append_op(
type='stack', inputs={'X': x}, outputs={'Y': out}, type='stack', inputs={'X': x}, outputs={'Y': out},
attrs={'axis': axis}) attrs={'axis': axis})
...@@ -6756,7 +6771,7 @@ def unstack(x, axis=0, num=None): ...@@ -6756,7 +6771,7 @@ def unstack(x, axis=0, num=None):
outs = [] outs = []
for _ in num: for _ in num:
outs.append(helper.create_tmp_variable(x.dtype)) outs.append(helper.create_variable_for_type_inference(x.dtype))
helper.append_op( helper.append_op(
type='unstack', type='unstack',
...@@ -6808,7 +6823,7 @@ def expand(x, expand_times, name=None): ...@@ -6808,7 +6823,7 @@ def expand(x, expand_times, name=None):
""" """
helper = LayerHelper('expand', input=x, **locals()) helper = LayerHelper('expand', input=x, **locals())
dtype = helper.input_dtype(input_param_name='x') dtype = helper.input_dtype(input_param_name='x')
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='expand', type='expand',
inputs={'X': x}, inputs={'X': x},
...@@ -6847,7 +6862,7 @@ def uniform_random_batch_size_like(input, ...@@ -6847,7 +6862,7 @@ def uniform_random_batch_size_like(input,
""" """
helper = LayerHelper('uniform_random_batch_size_like', **locals()) helper = LayerHelper('uniform_random_batch_size_like', **locals())
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
c_dtype = convert_np_dtype_to_dtype_(dtype) c_dtype = convert_np_dtype_to_dtype_(dtype)
helper.append_op( helper.append_op(
type='uniform_random_batch_size_like', type='uniform_random_batch_size_like',
...@@ -6884,7 +6899,7 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'): ...@@ -6884,7 +6899,7 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
""" """
helper = LayerHelper('gaussian_random', **locals()) helper = LayerHelper('gaussian_random', **locals())
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
c_dtype = convert_np_dtype_to_dtype_(dtype) c_dtype = convert_np_dtype_to_dtype_(dtype)
helper.append_op( helper.append_op(
type='gaussian_random', type='gaussian_random',
...@@ -6919,7 +6934,7 @@ def sampling_id(x, min=0.0, max=1.0, seed=0, dtype='float32'): ...@@ -6919,7 +6934,7 @@ def sampling_id(x, min=0.0, max=1.0, seed=0, dtype='float32'):
""" """
helper = LayerHelper('sampling_id', **locals()) helper = LayerHelper('sampling_id', **locals())
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
helper.append_op( helper.append_op(
type='sampling_id', type='sampling_id',
inputs={'X': x}, inputs={'X': x},
...@@ -6958,7 +6973,7 @@ def gaussian_random_batch_size_like(input, ...@@ -6958,7 +6973,7 @@ def gaussian_random_batch_size_like(input,
""" """
helper = LayerHelper('gaussian_random_batch_size_like', **locals()) helper = LayerHelper('gaussian_random_batch_size_like', **locals())
out = helper.create_tmp_variable(dtype) out = helper.create_variable_for_type_inference(dtype)
c_dtype = convert_np_dtype_to_dtype_(dtype) c_dtype = convert_np_dtype_to_dtype_(dtype)
helper.append_op( helper.append_op(
type='gaussian_random_batch_size_like', type='gaussian_random_batch_size_like',
...@@ -6990,7 +7005,8 @@ def sum(x): ...@@ -6990,7 +7005,8 @@ def sum(x):
""" """
helper = LayerHelper('sum', **locals()) helper = LayerHelper('sum', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype('x')) out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('x'))
helper.append_op( helper.append_op(
type='sum', type='sum',
inputs={'X': x}, inputs={'X': x},
...@@ -7017,7 +7033,8 @@ def slice(input, axes, starts, ends): ...@@ -7017,7 +7033,8 @@ def slice(input, axes, starts, ends):
""" """
helper = LayerHelper('slice', **locals()) helper = LayerHelper('slice', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype('input')) out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('input'))
helper.append_op( helper.append_op(
type='slice', type='slice',
inputs={'Input': input}, inputs={'Input': input},
...@@ -7043,7 +7060,8 @@ def shape(input): ...@@ -7043,7 +7060,8 @@ def shape(input):
""" """
helper = LayerHelper('shape', **locals()) helper = LayerHelper('shape', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype('input')) out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype('input'))
helper.append_op( helper.append_op(
type='shape', inputs={'Input': input}, outputs={'Out': out}) type='shape', inputs={'Input': input}, outputs={'Out': out})
...@@ -7060,7 +7078,7 @@ def _elementwise_op(helper): ...@@ -7060,7 +7078,7 @@ def _elementwise_op(helper):
use_mkldnn = helper.kwargs.get('use_mkldnn', False) use_mkldnn = helper.kwargs.get('use_mkldnn', False)
name = helper.kwargs.get('name', None) name = helper.kwargs.get('name', None)
if name is None: if name is None:
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
out = helper.create_variable( out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
...@@ -7094,7 +7112,7 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None): ...@@ -7094,7 +7112,7 @@ def scale(x, scale=1.0, bias=0.0, bias_after_scale=True, act=None, name=None):
helper = LayerHelper('scale', **locals()) helper = LayerHelper('scale', **locals())
if name is None: if name is None:
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
out = helper.create_variable( out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
...@@ -7160,7 +7178,7 @@ def _logical_op(op_name, x, y, out=None, name=None, binary_op=True): ...@@ -7160,7 +7178,7 @@ def _logical_op(op_name, x, y, out=None, name=None, binary_op=True):
if out is None: if out is None:
if name is None: if name is None:
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
out = helper.create_variable( out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
...@@ -7268,7 +7286,7 @@ def clip(x, min, max, name=None): ...@@ -7268,7 +7286,7 @@ def clip(x, min, max, name=None):
helper = LayerHelper("clip", **locals()) helper = LayerHelper("clip", **locals())
if name is None: if name is None:
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
out = helper.create_variable( out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
...@@ -7300,7 +7318,7 @@ def clip_by_norm(x, max_norm, name=None): ...@@ -7300,7 +7318,7 @@ def clip_by_norm(x, max_norm, name=None):
helper = LayerHelper("clip_by_norm", **locals()) helper = LayerHelper("clip_by_norm", **locals())
if name is None: if name is None:
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
out = helper.create_variable( out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
...@@ -7330,7 +7348,7 @@ def mean(x, name=None): ...@@ -7330,7 +7348,7 @@ def mean(x, name=None):
helper = LayerHelper("mean", **locals()) helper = LayerHelper("mean", **locals())
if name is None: if name is None:
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
out = helper.create_variable( out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
...@@ -7360,7 +7378,7 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None): ...@@ -7360,7 +7378,7 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
helper = LayerHelper("mul", **locals()) helper = LayerHelper("mul", **locals())
if name is None: if name is None:
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
out = helper.create_variable( out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
...@@ -7394,7 +7412,7 @@ def sigmoid_cross_entropy_with_logits(x, label, name=None): ...@@ -7394,7 +7412,7 @@ def sigmoid_cross_entropy_with_logits(x, label, name=None):
helper = LayerHelper("sigmoid_cross_entropy_with_logits", **locals()) helper = LayerHelper("sigmoid_cross_entropy_with_logits", **locals())
if name is None: if name is None:
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
out = helper.create_variable( out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
...@@ -7424,7 +7442,7 @@ def maxout(x, groups, name=None): ...@@ -7424,7 +7442,7 @@ def maxout(x, groups, name=None):
helper = LayerHelper("maxout", **locals()) helper = LayerHelper("maxout", **locals())
if name is None: if name is None:
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
out = helper.create_variable( out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
...@@ -7463,7 +7481,7 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None): ...@@ -7463,7 +7481,7 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
helper = LayerHelper("affine_channel", **locals()) helper = LayerHelper("affine_channel", **locals())
if name is None: if name is None:
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
else: else:
out = helper.create_variable( out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False) name=name, dtype=x.dtype, persistable=False)
......
...@@ -152,7 +152,7 @@ def cast(x, dtype): ...@@ -152,7 +152,7 @@ def cast(x, dtype):
result = fluid.layers.cast(x=data, dtype='float64') result = fluid.layers.cast(x=data, dtype='float64')
""" """
helper = LayerHelper('cast', **locals()) helper = LayerHelper('cast', **locals())
out = helper.create_tmp_variable(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op( helper.append_op(
type='cast', type='cast',
inputs={'X': [x]}, inputs={'X': [x]},
...@@ -184,7 +184,7 @@ def concat(input, axis=0, name=None): ...@@ -184,7 +184,7 @@ def concat(input, axis=0, name=None):
out = fluid.layers.concat(input=[Efirst, Esecond, Ethird, Efourth]) out = fluid.layers.concat(input=[Efirst, Esecond, Ethird, Efourth])
""" """
helper = LayerHelper('concat', **locals()) helper = LayerHelper('concat', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
helper.append_op( helper.append_op(
type='concat', type='concat',
inputs={'X': input}, inputs={'X': input},
...@@ -221,7 +221,8 @@ def sums(input, out=None): ...@@ -221,7 +221,8 @@ def sums(input, out=None):
""" """
helper = LayerHelper('sum', **locals()) helper = LayerHelper('sum', **locals())
if out is None: if out is None:
out = helper.create_tmp_variable(dtype=helper.input_dtype()) out = helper.create_variable_for_type_inference(
dtype=helper.input_dtype())
helper.append_op( helper.append_op(
type='sum', type='sum',
inputs={'X': input}, inputs={'X': input},
...@@ -252,7 +253,7 @@ def assign(input, output=None): ...@@ -252,7 +253,7 @@ def assign(input, output=None):
""" """
helper = LayerHelper('assign', **locals()) helper = LayerHelper('assign', **locals())
if output is None: if output is None:
output = helper.create_tmp_variable(dtype=input.dtype) output = helper.create_variable_for_type_inference(dtype=input.dtype)
if isinstance(input, Variable): if isinstance(input, Variable):
helper.append_op( helper.append_op(
type='assign', inputs={'X': [input]}, outputs={'Out': [output]}) type='assign', inputs={'X': [input]}, outputs={'Out': [output]})
...@@ -311,7 +312,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): ...@@ -311,7 +312,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None):
helper = LayerHelper("fill_constant", **locals()) helper = LayerHelper("fill_constant", **locals())
if out is None: if out is None:
out = helper.create_tmp_variable(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op( helper.append_op(
type='fill_constant', type='fill_constant',
inputs={}, inputs={},
...@@ -358,7 +359,7 @@ def fill_constant_batch_size_like(input, ...@@ -358,7 +359,7 @@ def fill_constant_batch_size_like(input,
${out_comment}. ${out_comment}.
""" """
helper = LayerHelper("fill_constant_batch_size_like", **locals()) helper = LayerHelper("fill_constant_batch_size_like", **locals())
out = helper.create_tmp_variable(dtype=dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op( helper.append_op(
type='fill_constant_batch_size_like', type='fill_constant_batch_size_like',
inputs={'Input': input}, inputs={'Input': input},
...@@ -396,7 +397,7 @@ def argmin(x, axis=0): ...@@ -396,7 +397,7 @@ def argmin(x, axis=0):
out = fluid.layers.argmin(x=in, axis=-1) out = fluid.layers.argmin(x=in, axis=-1)
""" """
helper = LayerHelper("arg_min", **locals()) helper = LayerHelper("arg_min", **locals())
out = helper.create_tmp_variable(VarDesc.VarType.INT64) out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64)
helper.append_op( helper.append_op(
type='arg_min', type='arg_min',
inputs={'X': x}, inputs={'X': x},
...@@ -427,7 +428,7 @@ def argmax(x, axis=0): ...@@ -427,7 +428,7 @@ def argmax(x, axis=0):
out = fluid.layers.argmax(x=in, axis=-1) out = fluid.layers.argmax(x=in, axis=-1)
""" """
helper = LayerHelper("arg_max", **locals()) helper = LayerHelper("arg_max", **locals())
out = helper.create_tmp_variable(VarDesc.VarType.INT64) out = helper.create_variable_for_type_inference(VarDesc.VarType.INT64)
helper.append_op( helper.append_op(
type='arg_max', type='arg_max',
inputs={'X': x}, inputs={'X': x},
...@@ -477,8 +478,10 @@ def argsort(input, axis=-1, name=None): ...@@ -477,8 +478,10 @@ def argsort(input, axis=-1, name=None):
out, indices = fluid.layers.argsort(input, axis=0) out, indices = fluid.layers.argsort(input, axis=0)
""" """
helper = LayerHelper("argsort", **locals()) helper = LayerHelper("argsort", **locals())
out = helper.create_tmp_variable(dtype=input.dtype, stop_gradient=True) out = helper.create_variable_for_type_inference(
ids = helper.create_tmp_variable(VarDesc.VarType.INT64, stop_gradient=True) dtype=input.dtype, stop_gradient=True)
ids = helper.create_variable_for_type_inference(
VarDesc.VarType.INT64, stop_gradient=True)
helper.append_op( helper.append_op(
type='argsort', type='argsort',
inputs={'X': input}, inputs={'X': input},
...@@ -562,7 +565,7 @@ def reverse(x, axis): ...@@ -562,7 +565,7 @@ def reverse(x, axis):
if isinstance(axis, int): if isinstance(axis, int):
axis = [axis] axis = [axis]
helper = LayerHelper("reverse", **locals()) helper = LayerHelper("reverse", **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='reverse', type='reverse',
inputs={'Input': x}, inputs={'Input': x},
...@@ -654,7 +657,7 @@ def has_inf(x): ...@@ -654,7 +657,7 @@ def has_inf(x):
Variable: The tensor variable storing the output, only a bool value. Variable: The tensor variable storing the output, only a bool value.
""" """
helper = LayerHelper("isinf", **locals()) helper = LayerHelper("isinf", **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type="isinf", inputs={"X": x}, outputs={"Out": out}) helper.append_op(type="isinf", inputs={"X": x}, outputs={"Out": out})
return out return out
...@@ -670,7 +673,7 @@ def has_nan(x): ...@@ -670,7 +673,7 @@ def has_nan(x):
Variable: The tensor variable storing the output, only a bool value. Variable: The tensor variable storing the output, only a bool value.
""" """
helper = LayerHelper("isnan", **locals()) helper = LayerHelper("isnan", **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type="isnan", inputs={"X": x}, outputs={"Out": out}) helper.append_op(type="isnan", inputs={"X": x}, outputs={"Out": out})
return out return out
...@@ -687,6 +690,6 @@ def isfinite(x): ...@@ -687,6 +690,6 @@ def isfinite(x):
Variable: The tensor variable storing the output, contains a bool value. Variable: The tensor variable storing the output, contains a bool value.
""" """
helper = LayerHelper("isfinite", **locals()) helper = LayerHelper("isfinite", **locals())
out = helper.create_tmp_variable(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type="isfinite", inputs={"X": x}, outputs={"Out": out}) helper.append_op(type="isfinite", inputs={"X": x}, outputs={"Out": out})
return out return out
...@@ -151,7 +151,7 @@ class L2DecayRegularizer(WeightDecayRegularizer): ...@@ -151,7 +151,7 @@ class L2DecayRegularizer(WeightDecayRegularizer):
decay = block.create_var( decay = block.create_var(
dtype="float32", dtype="float32",
shape=param.shape, shape=param.shape,
type=core.VarDesc.VarType.SELECTED_ROWS) type=core.VarDesc.VarType.LOD_TENSOR)
block.append_op( block.append_op(
type='extract_rows', inputs={'X': grad}, outputs={'Out': idx}) type='extract_rows', inputs={'X': grad}, outputs={'Out': idx})
block.append_op( block.append_op(
...@@ -228,7 +228,7 @@ class L1DecayRegularizer(WeightDecayRegularizer): ...@@ -228,7 +228,7 @@ class L1DecayRegularizer(WeightDecayRegularizer):
decay = block.create_var( decay = block.create_var(
dtype="float32", dtype="float32",
shape=param.shape, shape=param.shape,
type=core.VarDesc.VarType.SELECTED_ROWS) type=core.VarDesc.VarType.LOD_TENSOR)
block.append_op( block.append_op(
type='extract_rows', inputs={'X': grad}, outputs={'Out': idx}) type='extract_rows', inputs={'X': grad}, outputs={'Out': idx})
block.append_op( block.append_op(
......
if(NOT APPLE) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
set(PYTHON_TESTS_DIR ${CMAKE_CURRENT_BINARY_DIR} CACHE PATH "python tests directory")
else()
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
endif(NOT APPLE)
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
......
...@@ -301,7 +301,7 @@ class TestRpnTargetAssign(unittest.TestCase): ...@@ -301,7 +301,7 @@ class TestRpnTargetAssign(unittest.TestCase):
dtype='float32', dtype='float32',
lod_level=1, lod_level=1,
append_batch_size=False) append_batch_size=False)
pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign( pred_scores, pred_loc, tgt_lbl, tgt_bbox, bbox_inside_weight = layers.rpn_target_assign(
bbox_pred=bbox_pred, bbox_pred=bbox_pred,
cls_logits=cls_logits, cls_logits=cls_logits,
anchor_box=anchor_box, anchor_box=anchor_box,
...@@ -313,15 +313,18 @@ class TestRpnTargetAssign(unittest.TestCase): ...@@ -313,15 +313,18 @@ class TestRpnTargetAssign(unittest.TestCase):
rpn_straddle_thresh=0.0, rpn_straddle_thresh=0.0,
rpn_fg_fraction=0.5, rpn_fg_fraction=0.5,
rpn_positive_overlap=0.7, rpn_positive_overlap=0.7,
rpn_negative_overlap=0.3) rpn_negative_overlap=0.3,
use_random=False)
self.assertIsNotNone(pred_scores) self.assertIsNotNone(pred_scores)
self.assertIsNotNone(pred_loc) self.assertIsNotNone(pred_loc)
self.assertIsNotNone(tgt_lbl) self.assertIsNotNone(tgt_lbl)
self.assertIsNotNone(tgt_bbox) self.assertIsNotNone(tgt_bbox)
self.assertIsNotNone(bbox_inside_weight)
assert pred_scores.shape[1] == 1 assert pred_scores.shape[1] == 1
assert pred_loc.shape[1] == 4 assert pred_loc.shape[1] == 4
assert pred_loc.shape[1] == tgt_bbox.shape[1] assert pred_loc.shape[1] == tgt_bbox.shape[1]
print(str(program))
class TestGenerateProposals(unittest.TestCase): class TestGenerateProposals(unittest.TestCase):
......
...@@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE) ...@@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE)
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200) set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext) py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000) set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
# TODO: fix this test
#py_test_modules(test_dist_transformer MODULES test_dist_transformer) py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000) set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
endif(NOT APPLE) endif(NOT APPLE)
py_test_modules(test_dist_transpiler MODULES test_dist_transpiler) py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
endif() endif()
......
...@@ -35,7 +35,7 @@ import paddle ...@@ -35,7 +35,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from paddle.fluid import core from paddle.fluid import core
from test_dist_base import TestDistRunnerBase, runtime_main from test_dist_base import TestDistRunnerBase, runtime_main, RUN_STEP
import paddle.compat as cpt import paddle.compat as cpt
from paddle.compat import long_type from paddle.compat import long_type
...@@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
for pass_id in six.moves.xrange(TrainTaskConfig.pass_num): for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time() pass_start_time = time.time()
for batch_id, data in enumerate(train_data()): for batch_id, data in enumerate(train_data()):
if batch_id >= 5: if batch_id >= RUN_STEP:
break break
feed_list = [] feed_list = []
total_num_token = 0 total_num_token = 0
#if TrainTaskConfig.local:
# lr_rate = lr_scheduler.update_learning_rate()
#for place_id, data_buffer in enumerate(
# split_data(
# data, num_part=dev_count)):
if TrainTaskConfig.local: if TrainTaskConfig.local:
lr_rate = lr_scheduler.update_learning_rate() lr_rate = lr_scheduler.update_learning_rate()
...@@ -619,12 +613,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -619,12 +613,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
init = True init = True
# Validate and save the model for inference. # Validate and save the model for inference.
if batch_id == 0 or batch_id == 4: if TrainTaskConfig.val_file_pattern is not None:
if TrainTaskConfig.val_file_pattern is not None: val_avg_cost, val_ppl = test()
val_avg_cost, val_ppl = test() print("[%f]" % val_avg_cost)
print("[%f]" % val_avg_cost) else:
else: assert (False)
assert (False)
#import transformer_reader as reader #import transformer_reader as reader
...@@ -1701,7 +1694,7 @@ class DistTransformer2x2(TestDistRunnerBase): ...@@ -1701,7 +1694,7 @@ class DistTransformer2x2(TestDistRunnerBase):
def run_trainer(self, args): def run_trainer(self, args):
TrainTaskConfig.use_gpu = args.use_cuda TrainTaskConfig.use_gpu = args.use_cuda
sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model( sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model(
args.is_dist, not args.sync_mode) args.is_dist, not args.sync_mode)
if args.is_dist: if args.is_dist:
......
...@@ -40,7 +40,8 @@ class TestDistMnistAsync(TestDistBase): ...@@ -40,7 +40,8 @@ class TestDistMnistAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._use_reduce = False self._use_reduce = False
def test_dist_train(self): # FIXME(typhoonzero): fix async mode test later
def no_test_dist_train(self):
self.check_with_place("dist_mnist.py", delta=200) self.check_with_place("dist_mnist.py", delta=200)
......
...@@ -40,7 +40,8 @@ class TestDistSeResneXt2x2Async(TestDistBase): ...@@ -40,7 +40,8 @@ class TestDistSeResneXt2x2Async(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._use_reader_alloc = False self._use_reader_alloc = False
def test_dist_train(self): #FIXME(typhoonzero): fix async mode later
def no_test_dist_train(self):
self.check_with_place("dist_se_resnext.py", delta=100) self.check_with_place("dist_se_resnext.py", delta=100)
......
...@@ -42,7 +42,8 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase): ...@@ -42,7 +42,8 @@ class TestDistSimnetBow2x2DenseAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._enforce_place = "CPU" self._enforce_place = "CPU"
def test_simnet_bow(self): #FIXME(typhoonzero): fix async tests later
def no_test_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '0', "IS_DISTRIBUTED": '0',
"IS_SPARSE": '0', "IS_SPARSE": '0',
...@@ -78,7 +79,8 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase): ...@@ -78,7 +79,8 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
self._sync_mode = False self._sync_mode = False
self._enforce_place = "CPU" self._enforce_place = "CPU"
def test_simnet_bow(self): #FIXME(typhoonzero): fix async tests later
def no_test_simnet_bow(self):
need_envs = { need_envs = {
"IS_DISTRIBUTED": '0', "IS_DISTRIBUTED": '0',
"IS_SPARSE": '1', "IS_SPARSE": '1',
......
...@@ -61,7 +61,8 @@ class TestDistTransformer2x2Sync(TestDistBase): ...@@ -61,7 +61,8 @@ class TestDistTransformer2x2Sync(TestDistBase):
def test_dist_train(self): def test_dist_train(self):
download_files() download_files()
self.check_with_place("dist_transformer.py", delta=1e-5) self.check_with_place(
"dist_transformer.py", delta=1e-5, check_error_log=False)
class TestDistTransformer2x2Async(TestDistBase): class TestDistTransformer2x2Async(TestDistBase):
...@@ -70,7 +71,8 @@ class TestDistTransformer2x2Async(TestDistBase): ...@@ -70,7 +71,8 @@ class TestDistTransformer2x2Async(TestDistBase):
def test_dist_train(self): def test_dist_train(self):
download_files() download_files()
self.check_with_place("dist_transformer.py", delta=1.0) self.check_with_place(
"dist_transformer.py", delta=1.0, check_error_log=False)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp): ...@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp):
self.D = 8 self.D = 8
class TestFusionGRUOpMD3(TestFusionGRUOp):
def set_confs(self):
self.M = 17
self.D = 15
class TestFusionGRUOpBS1(TestFusionGRUOp): class TestFusionGRUOpBS1(TestFusionGRUOp):
def set_confs(self): def set_confs(self):
self.lod = [[3]] self.lod = [[3]]
......
...@@ -50,8 +50,10 @@ def rpn_target_assign(anchor_by_gt_overlap, ...@@ -50,8 +50,10 @@ def rpn_target_assign(anchor_by_gt_overlap,
fg_inds, size=(len(fg_inds) - num_fg), replace=False) fg_inds, size=(len(fg_inds) - num_fg), replace=False)
else: else:
disable_inds = fg_inds[num_fg:] disable_inds = fg_inds[num_fg:]
labels[disable_inds] = -1 labels[disable_inds] = -1
fg_inds = np.where(labels == 1)[0] fg_inds = np.where(labels == 1)[0]
bbox_inside_weight = np.zeros((len(fg_inds), 4), dtype=np.float32)
num_bg = rpn_batch_size_per_im - np.sum(labels == 1) num_bg = rpn_batch_size_per_im - np.sum(labels == 1)
bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0] bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0]
...@@ -59,18 +61,27 @@ def rpn_target_assign(anchor_by_gt_overlap, ...@@ -59,18 +61,27 @@ def rpn_target_assign(anchor_by_gt_overlap,
enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)] enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)]
else: else:
enable_inds = bg_inds[:num_bg] enable_inds = bg_inds[:num_bg]
fg_fake_inds = np.array([], np.int32)
fg_value = np.array([fg_inds[0]], np.int32)
fake_num = 0
for bg_id in enable_inds:
if bg_id in fg_inds:
fake_num += 1
fg_fake_inds = np.hstack([fg_fake_inds, fg_value])
labels[enable_inds] = 0 labels[enable_inds] = 0
bbox_inside_weight[fake_num:, :] = 1
fg_inds = np.where(labels == 1)[0] fg_inds = np.where(labels == 1)[0]
bg_inds = np.where(labels == 0)[0] bg_inds = np.where(labels == 0)[0]
loc_index = np.hstack([fg_fake_inds, fg_inds])
loc_index = fg_inds score_index = np.hstack([fg_inds, bg_inds])
score_index = np.hstack((fg_inds, bg_inds))
labels = labels[score_index] labels = labels[score_index]
assert not np.any(labels == -1), "Wrong labels with -1" assert not np.any(labels == -1), "Wrong labels with -1"
gt_inds = anchor_to_gt_argmax[fg_inds] gt_inds = anchor_to_gt_argmax[loc_index]
return loc_index, score_index, labels, gt_inds return loc_index, score_index, labels, gt_inds, bbox_inside_weight
def get_anchor(n, c, h, w): def get_anchor(n, c, h, w):
...@@ -123,9 +134,12 @@ def rpn_target_assign_in_python(all_anchors, ...@@ -123,9 +134,12 @@ def rpn_target_assign_in_python(all_anchors,
gt_boxes_slice = gt_boxes_slice[not_crowd_inds] gt_boxes_slice = gt_boxes_slice[not_crowd_inds]
iou = _bbox_overlaps(inside_anchors, gt_boxes_slice) iou = _bbox_overlaps(inside_anchors, gt_boxes_slice)
loc_inds, score_inds, labels, gt_inds = rpn_target_assign( loc_inds, score_inds, labels, gt_inds, bbox_inside_weight = \
iou, rpn_batch_size_per_im, rpn_positive_overlap, rpn_target_assign(iou, rpn_batch_size_per_im,
rpn_negative_overlap, rpn_fg_fraction, use_random) rpn_positive_overlap,
rpn_negative_overlap,
rpn_fg_fraction,
use_random)
# unmap to all anchor # unmap to all anchor
loc_inds = inds_inside[loc_inds] loc_inds = inds_inside[loc_inds]
score_inds = inds_inside[score_inds] score_inds = inds_inside[score_inds]
...@@ -139,6 +153,7 @@ def rpn_target_assign_in_python(all_anchors, ...@@ -139,6 +153,7 @@ def rpn_target_assign_in_python(all_anchors,
score_indexes = score_inds score_indexes = score_inds
tgt_labels = labels tgt_labels = labels
tgt_bboxes = box_deltas tgt_bboxes = box_deltas
bbox_inside_weights = bbox_inside_weight
else: else:
loc_indexes = np.concatenate( loc_indexes = np.concatenate(
[loc_indexes, loc_inds + i * anchor_num]) [loc_indexes, loc_inds + i * anchor_num])
...@@ -146,8 +161,10 @@ def rpn_target_assign_in_python(all_anchors, ...@@ -146,8 +161,10 @@ def rpn_target_assign_in_python(all_anchors,
[score_indexes, score_inds + i * anchor_num]) [score_indexes, score_inds + i * anchor_num])
tgt_labels = np.concatenate([tgt_labels, labels]) tgt_labels = np.concatenate([tgt_labels, labels])
tgt_bboxes = np.vstack([tgt_bboxes, box_deltas]) tgt_bboxes = np.vstack([tgt_bboxes, box_deltas])
bbox_inside_weights = np.vstack([bbox_inside_weights, \
bbox_inside_weight])
return loc_indexes, score_indexes, tgt_bboxes, tgt_labels return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights
class TestRpnTargetAssignOp(OpTest): class TestRpnTargetAssignOp(OpTest):
...@@ -182,10 +199,12 @@ class TestRpnTargetAssignOp(OpTest): ...@@ -182,10 +199,12 @@ class TestRpnTargetAssignOp(OpTest):
rpn_fg_fraction = 0.5 rpn_fg_fraction = 0.5
use_random = False use_random = False
loc_index, score_index, tgt_bbox, labels = rpn_target_assign_in_python( loc_index, score_index, tgt_bbox, labels, bbox_inside_weights = \
all_anchors, gt_boxes, is_crowd, im_info, lod, rpn_straddle_thresh, rpn_target_assign_in_python(all_anchors, gt_boxes, is_crowd,
rpn_batch_size_per_im, rpn_positive_overlap, rpn_negative_overlap, im_info, lod, rpn_straddle_thresh,
rpn_fg_fraction, use_random) rpn_batch_size_per_im, rpn_positive_overlap,
rpn_negative_overlap,
rpn_fg_fraction, use_random)
labels = labels[:, np.newaxis] labels = labels[:, np.newaxis]
self.op_type = "rpn_target_assign" self.op_type = "rpn_target_assign"
...@@ -207,7 +226,8 @@ class TestRpnTargetAssignOp(OpTest): ...@@ -207,7 +226,8 @@ class TestRpnTargetAssignOp(OpTest):
'LocationIndex': loc_index.astype('int32'), 'LocationIndex': loc_index.astype('int32'),
'ScoreIndex': score_index.astype('int32'), 'ScoreIndex': score_index.astype('int32'),
'TargetBBox': tgt_bbox.astype('float32'), 'TargetBBox': tgt_bbox.astype('float32'),
'TargetLabel': labels.astype('int32') 'TargetLabel': labels.astype('int32'),
'BBoxInsideWeight': bbox_inside_weights.astype('float32')
} }
def test_check_output(self): def test_check_output(self):
......
...@@ -30,7 +30,6 @@ class TestSliceVar(unittest.TestCase): ...@@ -30,7 +30,6 @@ class TestSliceVar(unittest.TestCase):
var = program.global_block().create_var( var = program.global_block().create_var(
name=str(random.randint(10000, 99999)), name=str(random.randint(10000, 99999)),
persistable=True, persistable=True,
# dtype=core.VarDesc.VarType.LOD_TENSOR,
shape=shape) shape=shape)
var_list.append(var) var_list.append(var)
blocks = slice_variable(var_list, 10, min_size) blocks = slice_variable(var_list, 10, min_size)
......
...@@ -21,22 +21,27 @@ from op_test import OpTest ...@@ -21,22 +21,27 @@ from op_test import OpTest
class TestTopkOp(OpTest): class TestTopkOp(OpTest):
def setUp(self): def setUp(self):
self.set_args()
self.op_type = "top_k" self.op_type = "top_k"
k = 1 k = self.top_k
input = np.random.random((32, 84)).astype("float32") input = np.random.random((self.row, k)).astype("float32")
output = np.ndarray((32, k)) output = np.ndarray((self.row, k))
indices = np.ndarray((32, k)).astype("int64") indices = np.ndarray((self.row, k)).astype("int64")
self.inputs = {'X': input} self.inputs = {'X': input}
self.attrs = {'k': k} self.attrs = {'k': k}
for rowid in range(32): for rowid in range(self.row):
row = input[rowid] row = input[rowid]
output[rowid] = np.sort(row)[-k:] output[rowid] = np.sort(row)[::-1][:k]
indices[rowid] = row.argsort()[-k:] indices[rowid] = row.argsort()[::-1][:k]
self.outputs = {'Out': output, 'Indices': indices} self.outputs = {'Out': output, 'Indices': indices}
def set_args(self):
self.row = 32
self.top_k = 1
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -50,14 +55,39 @@ class TestTopkOp3d(OpTest): ...@@ -50,14 +55,39 @@ class TestTopkOp3d(OpTest):
output = np.ndarray((64, k)) output = np.ndarray((64, k))
indices = np.ndarray((64, k)).astype("int64") indices = np.ndarray((64, k)).astype("int64")
# FIXME: should use 'X': input for a 3d input self.inputs = {'X': input}
self.inputs = {'X': input_flat_2d}
self.attrs = {'k': k} self.attrs = {'k': k}
for rowid in range(64): for rowid in range(64):
row = input_flat_2d[rowid] row = input_flat_2d[rowid]
output[rowid] = np.sort(row)[-k:] output[rowid] = np.sort(row)[::-1][:k]
indices[rowid] = row.argsort()[-k:] indices[rowid] = row.argsort()[::-1][:k]
self.outputs = {
'Out': output.reshape((32, 2, k)),
'Indices': indices.reshape((32, 2, k))
}
def test_check_output(self):
self.check_output()
class TestTopkOp2(OpTest):
def setUp(self):
self.op_type = "top_k"
k = 1
m = 2056
input = np.random.random((m, 84)).astype("float32")
output = np.ndarray((m, k))
indices = np.ndarray((m, k)).astype("int64")
self.inputs = {'X': input}
self.attrs = {'k': k}
for rowid in range(m):
row = input[rowid]
output[rowid] = -np.sort(-row)[:k]
indices[rowid] = (-row).argsort()[:k]
self.outputs = {'Out': output, 'Indices': indices} self.outputs = {'Out': output, 'Indices': indices}
...@@ -65,5 +95,17 @@ class TestTopkOp3d(OpTest): ...@@ -65,5 +95,17 @@ class TestTopkOp3d(OpTest):
self.check_output() self.check_output()
class TestTopkOp3(TestTopkOp):
def set_args(self):
self.row = 2056
self.top_k = 3
class TestTopkOp4(TestTopkOp):
def set_args(self):
self.row = 40000
self.top_k = 1
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册