未验证 提交 d8bfe83d 编写于 作者: zhouweiwei2014 提交者: GitHub

Mark extra attributes with AsExtra() for ops rnn/sequence_conv/sequence_pool/sequence_softmax (#35554)

上级 47d15a30
...@@ -162,8 +162,10 @@ class RNNOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -162,8 +162,10 @@ class RNNOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::string>( AddAttr<std::string>(
"mode", "mode",
"(string) rnn types, including: LSTM, GRU, RNN_RELU, RNN_TANH."); "(string) rnn types, including: LSTM, GRU, RNN_RELU, RNN_TANH.");
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(0); AddAttr<int>("seed", "seed to used if fix_seed is True").SetDefault(0);
AddAttr<bool>("is_test", "True if in test phase.")
.SetDefault(false)
.AsExtra();
AddComment(R"DOC( AddComment(R"DOC(
)DOC"); )DOC");
} }
......
...@@ -230,7 +230,7 @@ template <typename T> ...@@ -230,7 +230,7 @@ template <typename T>
void dropout_cpu_function_inplace(const framework::ExecutionContext& context, void dropout_cpu_function_inplace(const framework::ExecutionContext& context,
Tensor* x, Tensor* y, Tensor* mask, Tensor* x, Tensor* y, Tensor* mask,
const float& dropout_prob, const float& dropout_prob,
const int& seed_number, const bool& is_test, const int& seed_number, bool is_test,
bool* is_has_reset) { bool* is_has_reset) {
if (is_test) { if (is_test) {
return; return;
...@@ -816,7 +816,7 @@ void RnnFunc(const framework::ExecutionContext& ctx, const Tensor* input, ...@@ -816,7 +816,7 @@ void RnnFunc(const framework::ExecutionContext& ctx, const Tensor* input,
Tensor* dropout_mask, const int& num_layers, const int& gate_num, Tensor* dropout_mask, const int& num_layers, const int& gate_num,
const int& input_size, const int& hidden_size, const int& input_size, const int& hidden_size,
const bool& is_bidirec, const std::string& cell_type, const bool& is_bidirec, const std::string& cell_type,
const float& dropout_prob, const bool& is_test, const int& seed, const float& dropout_prob, bool is_test, const int& seed,
Tensor* reserve_data) { Tensor* reserve_data) {
const int& direction_num = is_bidirec ? 2 : 1; const int& direction_num = is_bidirec ? 2 : 1;
const auto& init_h_dims = init_h->dims(); const auto& init_h_dims = init_h->dims();
...@@ -952,8 +952,8 @@ class RNNCPUKernel : public framework::OpKernel<T> { ...@@ -952,8 +952,8 @@ class RNNCPUKernel : public framework::OpKernel<T> {
const int& hidden_size = ctx.Attr<int>("hidden_size"); const int& hidden_size = ctx.Attr<int>("hidden_size");
const float& dropout_prob = ctx.Attr<float>("dropout_prob"); const float& dropout_prob = ctx.Attr<float>("dropout_prob");
const std::string& mode = ctx.Attr<std::string>("mode"); const std::string& mode = ctx.Attr<std::string>("mode");
const bool& is_test = ctx.Attr<bool>("is_test");
const int& seed = ctx.Attr<int>("seed"); const int& seed = ctx.Attr<int>("seed");
bool is_test = ctx.HasAttr("is_test") ? ctx.Attr<bool>("is_test") : false;
bool has_seq_length = ctx.HasInput("SequenceLength"); bool has_seq_length = ctx.HasInput("SequenceLength");
const Tensor* sequence_length = nullptr; const Tensor* sequence_length = nullptr;
...@@ -1809,7 +1809,8 @@ void RnnGradFunc(const framework::ExecutionContext& context, ...@@ -1809,7 +1809,8 @@ void RnnGradFunc(const framework::ExecutionContext& context,
const int& num_layers = context.Attr<int>("num_layers"); const int& num_layers = context.Attr<int>("num_layers");
const bool& is_bidirec = context.Attr<bool>("is_bidirec"); const bool& is_bidirec = context.Attr<bool>("is_bidirec");
const float& dropout_prob = context.Attr<float>("dropout_prob"); const float& dropout_prob = context.Attr<float>("dropout_prob");
const bool& is_test = context.Attr<bool>("is_test"); bool is_test =
context.HasAttr("is_test") ? context.Attr<bool>("is_test") : false;
// get the input_size, batch_size, time_step, hidden_size // get the input_size, batch_size, time_step, hidden_size
const int& time_step = input->dims()[0]; const int& time_step = input->dims()[0];
......
...@@ -61,7 +61,8 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -61,7 +61,8 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.") "for training. Some layers may run faster when this is true.")
.SetDefault(false); .SetDefault(false)
.AsExtra();
AddAttr<std::string>( AddAttr<std::string>(
"pooltype", "pooltype",
"(string, default 'AVERAGE') the pooling pooltype of SequencePoolOp.") "(string, default 'AVERAGE') the pooling pooltype of SequencePoolOp.")
......
...@@ -67,7 +67,8 @@ class SequencePoolKernel : public framework::OpKernel<T> { ...@@ -67,7 +67,8 @@ class SequencePoolKernel : public framework::OpKernel<T> {
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
Tensor* index = nullptr; Tensor* index = nullptr;
const bool is_test = context.Attr<bool>("is_test"); bool is_test =
context.HasAttr("is_test") ? context.Attr<bool>("is_test") : false;
// Do not create index buffer for inference (is_test) mode // Do not create index buffer for inference (is_test) mode
// TODO(jczaja): Skip index buffer creation for other devices eg. GPU // TODO(jczaja): Skip index buffer creation for other devices eg. GPU
......
...@@ -34,7 +34,8 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { ...@@ -34,7 +34,8 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported. // choose cudnn kernel if the runtime supported.
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn =
ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
bool runtime_cudnn_support = false; bool runtime_cudnn_support = false;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
...@@ -47,7 +48,9 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { ...@@ -47,7 +48,9 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
if (use_cudnn && runtime_cudnn_support) { if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.HasAttr("data_format")
? ctx.Attr<std::string>("data_format")
: "AnyLayout";
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_); framework::StringToDataLayout(data_format), library_);
...@@ -66,14 +69,16 @@ class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -66,14 +69,16 @@ class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>( AddAttr<bool>(
"use_cudnn", "use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn") "(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false); .SetDefault(false)
.AsExtra();
AddAttr<std::string>( AddAttr<std::string>(
"data_format", "data_format",
"(string, default NCHW) Only used in " "(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". " "An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, " "Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ") "the input will be transformed automatically. ")
.SetDefault("AnyLayout"); .SetDefault("AnyLayout")
.AsExtra();
AddComment(R"DOC( AddComment(R"DOC(
Sequence Softmax Operator. Sequence Softmax Operator.
...@@ -130,7 +135,8 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { ...@@ -130,7 +135,8 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
// choose cudnn kernel if the runtime supported. // choose cudnn kernel if the runtime supported.
bool use_cudnn = ctx.Attr<bool>("use_cudnn"); bool use_cudnn =
ctx.HasAttr("use_cudnn") ? ctx.Attr<bool>("use_cudnn") : false;
bool runtime_cudnn_support = false; bool runtime_cudnn_support = false;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
...@@ -143,7 +149,9 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { ...@@ -143,7 +149,9 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
if (use_cudnn && runtime_cudnn_support) { if (use_cudnn && runtime_cudnn_support) {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.HasAttr("data_format")
? ctx.Attr<std::string>("data_format")
: "AnyLayout";
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace(), OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_); framework::StringToDataLayout(data_format), library_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册