From ab60fd8b6b670c7b65f0b8682d65f28fd279bf4c Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Mon, 10 Oct 2022 17:57:38 +0800 Subject: [PATCH] [PHI]Add RNN yaml (#46812) * add yaml entry for rnn and rrnn_grad, move infershape function for rnn_grad to phi infer meta * WIP: move rnn kernrl to phi * Change the code generation to avoid converting from intializer list to tuple of heterogeneous types. This is only triggered when an api has intermediate outputs, and the result of the outputs are of heterogeneous types. * fix the bug that when none in a vector of tensors requires gradient, the conversion to InferShapeContext to InferMetaContext (a.k.a. BuildInferMetaContext) produces errorous results. * fix ci bugs * fix ci bugs * fix ci bugs * modify code according comment Co-authored-by: chenfeiyu --- paddle/fluid/framework/infershape_utils.cc | 21 +- .../new_executor/new_executor_defs.cc | 9 +- paddle/fluid/framework/op_desc.cc | 9 +- paddle/fluid/framework/operator.cc | 9 +- paddle/fluid/imperative/infer_shape_context.h | 17 +- paddle/fluid/operators/rnn_op.cc | 29 +-- paddle/phi/api/yaml/generator/api_base.py | 238 ++++++++++++------ paddle/phi/api/yaml/generator/api_gen.py | 34 ++- paddle/phi/api/yaml/legacy_backward.yaml | 12 + paddle/phi/api/yaml/legacy_ops.yaml | 15 ++ paddle/phi/infermeta/backward.cc | 27 ++ paddle/phi/infermeta/backward.h | 7 + paddle/phi/infermeta/multiary.cc | 8 + .../fluid/tests/unittests/test_rnn_op.py | 24 ++ python/paddle/nn/layer/rnn.py | 8 +- 15 files changed, 319 insertions(+), 148 deletions(-) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index 92c1e50ff69..c7cd883cec4 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -31,6 +31,8 @@ limitations under the License. */ #include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/tensor_utils.h" +#include "glog/logging.h" + namespace paddle { namespace framework { @@ -270,6 +272,7 @@ void CompatMetaTensor::set_dims(const DDim& dims) { ValidCheck(*this); if (is_runtime_) { auto* var = PADDLE_GET(Variable*, var_); + if (var == nullptr) return; if (var->IsType()) { auto* tensor = var->GetMutable(); phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims; @@ -295,7 +298,9 @@ void CompatMetaTensor::set_dims(const DDim& dims) { } } else { auto* var = PADDLE_GET(VarDesc*, var_); - var->SetShape(vectorize(dims)); + if (var) { + var->SetShape(vectorize(dims)); + } } } @@ -303,6 +308,7 @@ void CompatMetaTensor::set_dtype(phi::DataType dtype) { ValidCheck(*this); if (is_runtime_) { auto* var = PADDLE_GET(Variable*, var_); + if (var == nullptr) return; if (var->IsType()) { auto* tensor = var->GetMutable(); phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype; @@ -321,7 +327,9 @@ void CompatMetaTensor::set_dtype(phi::DataType dtype) { } } else { auto* var = PADDLE_GET(VarDesc*, var_); - var->SetDataType(paddle::framework::TransToProtoVarType(dtype)); + if (var) { + var->SetDataType(paddle::framework::TransToProtoVarType(dtype)); + } } } @@ -329,6 +337,7 @@ void CompatMetaTensor::set_layout(DataLayout layout) { ValidCheck(*this); if (is_runtime_) { auto* var = PADDLE_GET(Variable*, var_); + if (var == nullptr) return; if (var->IsType()) { auto* tensor = var->GetMutable(); phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout; @@ -357,6 +366,7 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) { ValidCheck(meta_tensor); if (is_runtime_) { auto* var = PADDLE_GET(Variable*, var_); + if (var == nullptr) return; if (var->IsType() && meta_tensor.is_dense()) { auto* tensor = var->GetMutable(); phi::DenseTensorUtils::GetMutableMeta(tensor)->lod = @@ -371,8 +381,10 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) { VLOG(3) << "input metatensor is not LoDTensor or LoDTensorArray."; return; } - var->SetLoDLevel( - static_cast(meta_tensor).GetCompileTimeLoD()); + if (var) { + var->SetLoDLevel(static_cast(meta_tensor) + .GetCompileTimeLoD()); + } } } @@ -382,6 +394,7 @@ void CompatMetaTensor::share_dims(const MetaTensor& meta_tensor) { set_dims(meta_tensor.dims()); if (is_runtime_) { auto* var = PADDLE_GET(Variable*, var_); + if (var == nullptr) return; if (var->IsType()) { auto* selected_rows = var->GetMutable(); auto& input_selected_rows = diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index 643d6b78c98..1ec0490b641 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -94,17 +94,12 @@ bool InterpretercoreInferShapeContext::HasOutputs(const std::string& name, if (it == outs.end() || it->second.empty()) { return false; } - if (allow_null) { - for (auto& output : it->second) { - if (output != nullptr) return true; - } - return false; - } else { + if (!allow_null) { for (auto& output : it->second) { if (output == nullptr) return false; } - return true; } + return true; } AttrReader InterpretercoreInferShapeContext::Attrs() const { diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index f2474cda0a9..3042dfa00d6 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -1227,17 +1227,12 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name, if (output_names.empty()) { return false; } - if (allow_null) { - for (auto &output : output_names) { - if (block_.HasVarRecursive(output)) return true; - } - return false; - } else { + if (!allow_null) { for (auto &output : output_names) { if (!block_.HasVarRecursive(output)) return false; } - return true; } + return true; } AttrReader CompileTimeInferShapeContext::Attrs() const { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index f6988f002c5..b1f632d7f2e 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -755,17 +755,12 @@ class RuntimeInferShapeContext : public InferShapeContext { if (it == outs.end() || it->second.empty()) { return false; } - if (allow_null) { - for (auto& output : it->second) { - if (output != nullptr) return true; - } - return false; - } else { + if (!allow_null) { for (auto& output : it->second) { if (output == nullptr) return false; } - return true; } + return true; } AttrReader Attrs() const override { diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index 5702bcfca73..93efc0d7021 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -109,21 +109,14 @@ class DygraphInferShapeContext : public framework::InferShapeContext { if (it == var_map_out_->end() || it->second.empty()) { return false; } - if (allow_null) { - for (auto& output : it->second) { - if (output != nullptr) { - return true; - } - } - return false; - } else { + if (!allow_null) { for (auto& output : it->second) { if (output == nullptr) { return false; } } - return true; } + return true; } framework::AttrReader Attrs() const override { @@ -288,7 +281,11 @@ class DygraphInferShapeContext : public framework::InferShapeContext { var_map_out_->end(), platform::errors::NotFound("Can not find [%s] in outputs.", name)); for (auto& var : it->second) { - res.emplace_back(var->MutableVar()); + if (var) { + res.emplace_back(var->MutableVar()); + } else { + res.emplace_back(framework::InferShapeVarPtr()); + } } return res; } diff --git a/paddle/fluid/operators/rnn_op.cc b/paddle/fluid/operators/rnn_op.cc index aba720a99ba..3528cc957fa 100644 --- a/paddle/fluid/operators/rnn_op.cc +++ b/paddle/fluid/operators/rnn_op.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/multiary.h" namespace paddle { @@ -115,29 +116,6 @@ class RNNGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "RNN"); - OP_INOUT_CHECK(ctx->HasInputs("PreState"), "Input", "PreState", "RNN"); - OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "RNN"); - // OP_INOUT_CHECK(ctx->HasInputs("State"), "Input", "State", "RNN"); - - auto SetOutGradDim = [&ctx](const std::string& name) { - auto g_name = framework::GradVarName(name); - if (ctx->HasOutput(g_name)) { - ctx->SetOutputDim(g_name, ctx->GetInputDim(name)); - } - }; - - SetOutGradDim("Input"); - if (ctx->HasOutputs(framework::GradVarName("WeightList"))) { - ctx->SetOutputsDim(framework::GradVarName("WeightList"), - ctx->GetInputsDim("WeightList")); - } - if (ctx->HasOutputs(framework::GradVarName("PreState"))) { - ctx->SetOutputsDim(framework::GradVarName("PreState"), - ctx->GetInputsDim("PreState")); - } - } framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( @@ -192,6 +170,9 @@ namespace ops = paddle::operators; DECLARE_INFER_SHAPE_FUNCTOR(rnn, RnnInferShapeFunctor, PD_INFER_META(phi::RnnInferMeta)); +DECLARE_INFER_SHAPE_FUNCTOR(rnn_grad, + RnnGradInferShapeFunctor, + PD_INFER_META(phi::RnnGradInferMeta)); REGISTER_OPERATOR(rnn, ops::RNNOp, @@ -199,4 +180,4 @@ REGISTER_OPERATOR(rnn, ops::RNNGradOpMaker, ops::RNNGradOpMaker, RnnInferShapeFunctor); -REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp); +REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp, RnnGradInferShapeFunctor); diff --git a/paddle/phi/api/yaml/generator/api_base.py b/paddle/phi/api/yaml/generator/api_base.py index 26cb5c6de3e..534158eb31c 100644 --- a/paddle/phi/api/yaml/generator/api_base.py +++ b/paddle/phi/api/yaml/generator/api_base.py @@ -49,6 +49,23 @@ class BaseAPI(object): self.data_transform = self.parse_data_transform(api_item_yaml) self.inplace_map, self.view_map = {}, {} + self.gene_input_func = { + "const Tensor&": { + "dense": self.gene_dense_input, + "selected_rows": self.gene_selected_rows_input + }, + "const paddle::optional&": { + "dense": self.gene_dense_input, + "selected_rows": self.gene_selected_rows_input + }, + "const std::vector&": { + "dense": self.gene_vec_dense_input + }, + "const paddle::optional>&": { + "dense": self.gene_optional_vec_dense_input + } + } + def get_api_name(self, api_item_yaml): return api_item_yaml['op'] @@ -550,66 +567,71 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d {code_indent} phi::{infer_meta['func']}({param_code}); """ - def get_kernel_args(self, kernel_tensor_type=None, code_indent=''): - dense_input_trans_map = { - 'const Tensor&': - 'const phi::DenseTensor&', - 'const std::vector&': - 'const std::vector&', - 'const paddle::optional': - 'paddle::optional', - 'const paddle::optional&': - 'const paddle::optional&', - 'const paddle::optional>&': - 'const paddle::optional>&' - } - dense_out_trans_map = { - 'Tensor': 'phi::DenseTensor*', - 'std::vector': 'std::vector&' - } - sr_input_trans_map = { - 'const Tensor&': - 'const phi::SelectedRows&', - 'const paddle::optional&': - 'const paddle::optional&' - } - sr_out_trans_map = {'Tensor': 'phi::SelectedRows*'} + def gene_trans_flag(self, input_name): + trans_flag = "{}" + if input_name in self.data_transform['skip_transform']: + trans_flag = "{true}" + elif input_name in self.data_transform['support_trans_dtype']: + trans_flag = "{false, true}" + return trans_flag + + def gene_dense_input(self, + input_name, + input_name_tensor_map, + code_indent=''): + input_tensor_code = "" + trans_flag = self.gene_trans_flag(input_name) input_names = self.inputs['names'] - input_infos = self.inputs['input_info'] - kernel_args_type_list = ['const platform::DeviceContext&'] + attr_names = self.attrs['names'] + kernel_param = self.kernel['param'] + if kernel_param is None: + kernel_param = input_names + attr_names + + input_name_tensor_map[input_name].append( + (f"{PREFIX_TENSOR_NAME}{input_name}", False)) + input_tensor_code = input_tensor_code + f""" +{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});""" + return input_tensor_code + def gene_selected_rows_input(self, + input_name, + input_name_tensor_map, + code_indent=''): + input_tensor_code = "" + trans_flag = self.gene_trans_flag(input_name) + input_names = self.inputs['names'] attr_names = self.attrs['names'] kernel_param = self.kernel['param'] if kernel_param is None: kernel_param = input_names + attr_names + input_name_tensor_map[input_name].append( + (f"{PREFIX_TENSOR_NAME}{input_name}", False)) + input_tensor_code = input_tensor_code + f""" +{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name}); +""" + return input_tensor_code + + def gene_optional_vec_dense_input(self, + input_name, + input_name_tensor_map, + code_indent=''): input_tensor_code = "" - input_name_tensor_map = collections.defaultdict(list) - for i, input_name in enumerate(input_names): - # set input code - if input_name in kernel_param: - # input is dense tensor - if kernel_tensor_type is None or kernel_tensor_type[0][ - kernel_param.index(input_name)] == 'dense': - trans_flag = "{}" - if input_name in self.data_transform['skip_transform']: - trans_flag = "{true}" - elif input_name in self.data_transform[ - 'support_trans_dtype']: - trans_flag = "{false, true}" - if input_name in self.optional_vars: - if self.inputs['input_info'][ - input_name] == "const paddle::optional>&": - if input_name in self.inplace_map.values(): - input_name_tensor_map[input_name].append( - (f"{PREFIX_TENSOR_NAME}{input_name}", True)) - input_tensor_code = input_tensor_code + f""" + trans_flag = self.gene_trans_flag(input_name) + input_names = self.inputs['names'] + attr_names = self.attrs['names'] + kernel_param = self.kernel['param'] + if kernel_param is None: + kernel_param = input_names + attr_names + if input_name in self.inplace_map.values(): + input_name_tensor_map[input_name].append( + (f"{PREFIX_TENSOR_NAME}{input_name}", True)) + input_tensor_code = input_tensor_code + f""" {code_indent} paddle::optional> {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});""" - else: - input_name_tensor_map[input_name].append( - (f"{PREFIX_TENSOR_NAME}{input_name}_vec", - True)) - input_tensor_code = input_tensor_code + f""" + else: + input_name_tensor_map[input_name].append( + (f"{PREFIX_TENSOR_NAME}{input_name}_vec", True)) + input_tensor_code = input_tensor_code + f""" {code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag}); {code_indent} paddle::optional> {PREFIX_TENSOR_NAME}{input_name}; {code_indent} if ({PREFIX_TENSOR_NAME}{input_name}_vec){{ @@ -618,47 +640,58 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d {code_indent} {PREFIX_TENSOR_NAME}{input_name}->at(i) = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i); {code_indent} }} {code_indent} }}""" - else: - input_name_tensor_map[input_name].append( - (f"{PREFIX_TENSOR_NAME}{input_name}", False)) - input_tensor_code = input_tensor_code + f""" -{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});""" + return input_tensor_code - else: - if self.inputs['input_info'][ - input_name] == "const Tensor&": - input_name_tensor_map[input_name].append( - (f"{PREFIX_TENSOR_NAME}{input_name}", False)) - input_tensor_code = input_tensor_code + f""" -{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});""" + def gene_vec_dense_input(self, + input_name, + input_name_tensor_map, + code_indent=''): + input_tensor_code = "" + trans_flag = self.gene_trans_flag(input_name) + input_names = self.inputs['names'] + attr_names = self.attrs['names'] + kernel_param = self.kernel['param'] + if kernel_param is None: + kernel_param = input_names + attr_names - elif self.inputs['input_info'][ - input_name] == "const std::vector&": - if input_name in self.inplace_map.values(): - input_name_tensor_map[input_name].append( - (f"{PREFIX_TENSOR_NAME}{input_name}", True)) - input_tensor_code = input_tensor_code + f""" + if input_name in self.inplace_map.values(): + input_name_tensor_map[input_name].append( + (f"{PREFIX_TENSOR_NAME}{input_name}", True)) + input_tensor_code = input_tensor_code + f""" {code_indent} std::vector {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});""" - else: - input_name_tensor_map[input_name].append( - (f"{PREFIX_TENSOR_NAME}{input_name}_vec", - True)) - input_tensor_code = input_tensor_code + f""" + else: + input_name_tensor_map[input_name].append( + (f"{PREFIX_TENSOR_NAME}{input_name}_vec", True)) + input_tensor_code = input_tensor_code + f""" {code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag}); {code_indent} std::vector {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size()); {code_indent} for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{ {code_indent} {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i); {code_indent} }}""" + return input_tensor_code - else: - # do nothing - pass - else: # input is selected_rows - input_name_tensor_map[input_name].append( - (f"{PREFIX_TENSOR_NAME}{input_name}", False)) - input_tensor_code = input_tensor_code + f""" -{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name}); -""" + def gene_input(self, kernel_tensor_type=None, code_indent=''): + input_names = self.inputs['names'] + attr_names = self.attrs['names'] + kernel_param = self.kernel['param'] + if kernel_param is None: + kernel_param = input_names + attr_names + input_name_tensor_map = collections.defaultdict(list) + input_tensor_code = "" + for i, input_name in enumerate(input_names): + # set input code + if input_name in kernel_param: + # input is dense tensor + api_tensor_type = self.inputs['input_info'][input_name] + phi_tensor_type = 'dense' if kernel_tensor_type is None else kernel_tensor_type[ + 0][kernel_param.index(input_name)] + if api_tensor_type in self.gene_input_func.keys(): + input_tensor_code += self.gene_input_func[api_tensor_type][ + phi_tensor_type](input_name, input_name_tensor_map, + code_indent) + else: + # do nothing + pass else: if input_name in self.infer_meta['param']: if input_name in self.optional_vars: @@ -674,6 +707,45 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d else: input_tensor_code = input_tensor_code + f""" {code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = {input_name}.impl();""" + + return input_name_tensor_map, input_tensor_code + + def get_kernel_args(self, kernel_tensor_type=None, code_indent=''): + dense_input_trans_map = { + 'const Tensor&': + 'const phi::DenseTensor&', + 'const std::vector&': + 'const std::vector&', + 'const paddle::optional': + 'paddle::optional', + 'const paddle::optional&': + 'const paddle::optional&', + 'const paddle::optional>&': + 'const paddle::optional>&' + } + dense_out_trans_map = { + 'Tensor': 'phi::DenseTensor*', + 'std::vector': 'std::vector&' + } + sr_input_trans_map = { + 'const Tensor&': + 'const phi::SelectedRows&', + 'const paddle::optional&': + 'const paddle::optional&' + } + sr_out_trans_map = {'Tensor': 'phi::SelectedRows*'} + input_names = self.inputs['names'] + input_infos = self.inputs['input_info'] + kernel_args_type_list = ['const platform::DeviceContext&'] + + attr_names = self.attrs['names'] + kernel_param = self.kernel['param'] + if kernel_param is None: + kernel_param = input_names + attr_names + + input_name_tensor_map, input_tensor_code = self.gene_input( + kernel_tensor_type, code_indent) + input_tensor_code = input_tensor_code + f""" {code_indent} if(platform::RecordOpInfoSupplement::IsEnabled()){{""" single_tensor_names = [] diff --git a/paddle/phi/api/yaml/generator/api_gen.py b/paddle/phi/api/yaml/generator/api_gen.py index 80a9d586ca3..3add9ed13b7 100644 --- a/paddle/phi/api/yaml/generator/api_gen.py +++ b/paddle/phi/api/yaml/generator/api_gen.py @@ -45,6 +45,26 @@ class ForwardAPI(BaseAPI): else: return self.api + def gene_input(self, kernel_tensor_type=None, code_indent=''): + kernel_param = self.kernel['param'] + input_name_tensor_map, input_tensor_code = super().gene_input( + kernel_tensor_type, code_indent) + + # generate the input that is in view list + for i, input_name in enumerate(self.inputs['names']): + if input_name in self.view_map.values( + ) and input_name not in input_name_tensor_map.keys(): + if kernel_tensor_type is None or kernel_tensor_type[0][ + kernel_param.index(input_name)] == 'dense': + trans_flag = self.gene_trans_flag(input_name) + input_tensor_code = input_tensor_code + f""" +{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt(0), {trans_flag});""" + else: + # do nothing + pass + + return input_name_tensor_map, input_tensor_code + def parse_intermediate(self, api_item_yaml): if 'intermediate' in api_item_yaml: intermediate_outs = [ @@ -215,11 +235,15 @@ class ForwardAPI(BaseAPI): if not inplace_flag and self.view_map is not None and self.outputs[ 'names'][i] in self.view_map: - output_create = output_create + f""" -{code_indent} kernel_out_{i}->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]}); -{code_indent} kernel_out_{i}->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]}); -{code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";""" - + if out_dtype_list[i] == 'Tensor': + output_create = output_create + f""" + {code_indent} kernel_out_{i}->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]}); + {code_indent} kernel_out_{i}->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]}); + {code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";""" + else: + raise ValueError( + "{} : Output error: only support Tensor type when use view in yaml. But get {}" + .format(self.api, out_dtype_list[i])) else: raise ValueError( "{} : Output error: the output should not be empty.".format( diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 2f4eff98356..5654d037470 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1928,6 +1928,18 @@ output : Tensor(x_grad) invoke : reverse(out_grad, axis) +- backward_op : rnn_grad + forward : rnn (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor dropout_state_in, float dropout_prob, bool is_bidirec, int input_size, int hidden_size, int num_layers, str mode, int seed, bool is_test) -> Tensor(out), Tensor(dropout_state_out), Tensor[](state), Tensor(reserve) + args : (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor out, Tensor dropout_state_out, Tensor reserve, Tensor out_grad, Tensor[] state_grad, float dropout_prob, bool is_bidirec, int input_size, int hidden_size, int num_layers, str mode, int seed, bool is_test) + output : Tensor(x_grad), Tensor[](pre_state_grad){pre_state.size()}, Tensor[](weight_list_grad){weight_list.size()} + infer_meta : + func : RnnGradInferMeta + param : [x, pre_state, weight_list] + kernel : + func : rnn_grad + data_type: out_grad + optional : sequence_length + - backward_op : roi_align_grad forward : roi_align (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned) -> Tensor(out) args : (Tensor x, Tensor boxes, Tensor boxes_num, Tensor out_grad, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index bca2aeb58a5..2a43d30b5d5 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -2991,6 +2991,21 @@ func: overlap_add backward: overlap_add_grad +- op: rnn + args: (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor dropout_state_in, float dropout_prob=0.0, bool is_bidirec=false, int input_size=10, int hidden_size=100, int num_layers=1, str mode="RNN_TANH", int seed=0, bool is_test=false) + output: Tensor(out), Tensor(dropout_state_out), Tensor[](state){pre_state.size()}, Tensor(reserve) + infer_meta: + func: RnnInferMeta + param : [x, pre_state, weight_list, sequence_length, dropout_prob, is_bidirec, input_size, hidden_size, num_layers, mode, seed, is_test] + kernel: + func: rnn + param : [x, pre_state, weight_list, sequence_length, dropout_prob, is_bidirec, input_size, hidden_size, num_layers, mode, seed, is_test] + data_type: x + backward: rnn_grad + optional : sequence_length + intermediate : reserve + view : (dropout_state_in -> dropout_state_out) + - op: uniform_random_inplace args: (Tensor x, float min, float max, int seed, int diag_num, int diag_step, float diag_val) output: Tensor(out) diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 9f7eba86b8f..fd179a754a2 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -807,6 +807,33 @@ void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad, } } +void RnnGradInferMeta(const MetaTensor& x, + const std::vector& pre_state, + const std::vector& weight_list, + MetaTensor* x_grad, + std::vector pre_state_grad, + std::vector weight_grad_list) { + PADDLE_ENFORCE_GT( + pre_state.size(), + 0UL, + phi::errors::InvalidArgument( + "The input pre_state in RnnGradInferMeta can't be empty.")); + PADDLE_ENFORCE_GT( + weight_grad_list.size(), + 0UL, + phi::errors::InvalidArgument( + "The input weight_grad_list in RnnGradInferMeta can't be empty.")); + if (x_grad) { + UnchangedInferMeta(x, x_grad); + } + if (pre_state_grad.size()) { + UnchangedMultiInferMeta(pre_state, pre_state_grad); + } + if (weight_grad_list.size()) { + UnchangedMultiInferMeta(weight_list, weight_grad_list); + } +} + void ScatterGradInferMeta(const MetaTensor& index, const MetaTensor& updates, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 3e7cfa3ad83..38372af4b30 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -339,6 +339,13 @@ void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad, const MetaTensor& x_grad_grad, MetaTensor* out_grad_grad); +void RnnGradInferMeta(const MetaTensor& x, + const std::vector& pre_state, + const std::vector& weight_list, + MetaTensor* x_grad, + std::vector pre_state_grad, + std::vector weight_grad_list); + void ScatterGradInferMeta(const MetaTensor& index, const MetaTensor& updates, const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index e0e7299e8a8..5868d974122 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2496,6 +2496,14 @@ void StackInferMeta(const std::vector& x, void UnchangedMultiInferMeta(const std::vector& x, std::vector out) { + PADDLE_ENFORCE_EQ( + x.size(), + out.size(), + phi::errors::InvalidArgument( + "Input's size should be equal to the output's size" + "but received input size: (%d) does not equals output_size: (%d)", + x.size(), + out.size())); for (size_t i = 0; i < x.size(); ++i) { if (out[i]) { out[i]->share_meta(*x[i]); diff --git a/python/paddle/fluid/tests/unittests/test_rnn_op.py b/python/paddle/fluid/tests/unittests/test_rnn_op.py index e7e99c1bf5b..3bf6701e259 100644 --- a/python/paddle/fluid/tests/unittests/test_rnn_op.py +++ b/python/paddle/fluid/tests/unittests/test_rnn_op.py @@ -142,6 +142,30 @@ class TestRNNOp(OpTest): self.check_grad(set(grad_check_list), ['Out', 'last_hidden', 'last_cell']) + def test_grad_only_input(self): + if not self.is_test: + var_name_list = self.get_weight_names() + grad_check_list = ['Input'] + grad_check_list.extend(var_name_list) + self.check_grad(set(grad_check_list), + ['Out', 'last_hidden', 'last_cell']) + + def test_grad_only_h(self): + if not self.is_test: + var_name_list = self.get_weight_names() + grad_check_list = ['init_h'] + grad_check_list.extend(var_name_list) + self.check_grad(set(grad_check_list), + ['Out', 'last_hidden', 'last_cell']) + + def test_grad_only_c(self): + if not self.is_test: + var_name_list = self.get_weight_names() + grad_check_list = ['init_c'] + grad_check_list.extend(var_name_list) + self.check_grad(set(grad_check_list), + ['Out', 'last_hidden', 'last_cell']) + class TestRNNOp1(TestRNNOp): diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py index d48219fee48..a927258bf7b 100644 --- a/python/paddle/nn/layer/rnn.py +++ b/python/paddle/nn/layer/rnn.py @@ -1012,7 +1012,13 @@ class RNNBase(LayerList): if not self.time_major: inputs = paddle.tensor.transpose(inputs, [1, 0, 2]) - if in_dynamic_mode(): + if in_dygraph_mode(): + out, _, state = _C_ops.rnn( + inputs, initial_states, self._all_weights, sequence_length, + self._dropout_state, self.dropout, self.num_directions == 2, + self.input_size, self.hidden_size, self.num_layers, self.mode, + 0, not self.training) + elif in_dynamic_mode(): _, _, out, state = _legacy_C_ops.rnn( inputs, initial_states, self._all_weights, sequence_length, self._dropout_state, self.state_components, 'dropout_prob', -- GitLab