Unverified commit ab60fd8b authored by YuanRisheng, committed by GitHub

[PHI]Add RNN yaml (#46812)

* add yaml entry for rnn and rnn_grad, move infershape function for rnn_grad to phi infer meta

* WIP: move rnn kernel to phi

* Change the code generation to avoid converting from an initializer list to a tuple of heterogeneous types.
This is only triggered when an api has intermediate outputs and the outputs are of heterogeneous types.
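A minimal, self-contained sketch (hypothetical types, not the generated Paddle code) of the pitfall this bullet refers to: a braced initializer list must deduce a single element type, so it cannot carry heterogeneous outputs, while an explicit std::tuple can.

```cpp
#include <tuple>
#include <vector>

int main() {
  int out = 1;                        // e.g. a single output
  std::vector<float> state = {0.5f};  // e.g. a vector-valued intermediate output

  // auto bad = {out, state};         // error: initializer_list<T> needs one T
  auto good = std::make_tuple(out, state);  // heterogeneous types are fine
  (void)good;
  return 0;
}
```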

* fix the bug that when none of the tensors in a vector requires gradient, the conversion from InferShapeContext to InferMetaContext (a.k.a. BuildInferMetaContext) produces erroneous results.
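A minimal, self-contained sketch (hypothetical types, not Paddle's CompatMetaTensor) of the guard pattern the diff adds below: slots for tensors that do not require gradient arrive as null pointers, so each setter returns early instead of dereferencing them.

```cpp
#include <cstdio>
#include <vector>

struct Var { std::vector<long> dims; };  // stand-in for a framework Variable

// Mirrors the `if (var == nullptr) return;` guards added in this commit.
void SetDims(Var* var, const std::vector<long>& dims) {
  if (var == nullptr) return;  // tensor does not require gradient: no-op
  var->dims = dims;
}

int main() {
  Var grad_x;
  std::vector<Var*> outputs = {&grad_x, nullptr};  // second slot has no grad
  for (Var* v : outputs) SetDims(v, {2, 3});       // safe for null slots
  std::printf("%zu\n", grad_x.dims.size());        // prints 2
  return 0;
}
```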

* fix ci bugs

* fix ci bugs

* fix ci bugs

* modify code according to review comments
Co-authored-by: chenfeiyu <chenfeiyu@baidu.com>
Parent 715b9d66
......@@ -31,6 +31,8 @@ limitations under the License. */
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/tensor_utils.h"
#include "glog/logging.h"
namespace paddle {
namespace framework {
......@@ -270,6 +272,7 @@ void CompatMetaTensor::set_dims(const DDim& dims) {
ValidCheck(*this);
if (is_runtime_) {
auto* var = PADDLE_GET(Variable*, var_);
if (var == nullptr) return;
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
......@@ -295,14 +298,17 @@ void CompatMetaTensor::set_dims(const DDim& dims) {
}
} else {
auto* var = PADDLE_GET(VarDesc*, var_);
if (var) {
var->SetShape(vectorize(dims));
}
}
}
void CompatMetaTensor::set_dtype(phi::DataType dtype) {
ValidCheck(*this);
if (is_runtime_) {
auto* var = PADDLE_GET(Variable*, var_);
if (var == nullptr) return;
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
......@@ -321,14 +327,17 @@ void CompatMetaTensor::set_dtype(phi::DataType dtype) {
}
} else {
auto* var = PADDLE_GET(VarDesc*, var_);
if (var) {
var->SetDataType(paddle::framework::TransToProtoVarType(dtype));
}
}
}
void CompatMetaTensor::set_layout(DataLayout layout) {
ValidCheck(*this);
if (is_runtime_) {
auto* var = PADDLE_GET(Variable*, var_);
if (var == nullptr) return;
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
......@@ -357,6 +366,7 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
ValidCheck(meta_tensor);
if (is_runtime_) {
auto* var = PADDLE_GET(Variable*, var_);
if (var == nullptr) return;
if (var->IsType<phi::DenseTensor>() && meta_tensor.is_dense()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->lod =
......@@ -371,8 +381,10 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
VLOG(3) << "input metatensor is not LoDTensor or LoDTensorArray.";
return;
}
var->SetLoDLevel(
static_cast<const CompatMetaTensor&>(meta_tensor).GetCompileTimeLoD());
if (var) {
var->SetLoDLevel(static_cast<const CompatMetaTensor&>(meta_tensor)
.GetCompileTimeLoD());
}
}
}
......@@ -382,6 +394,7 @@ void CompatMetaTensor::share_dims(const MetaTensor& meta_tensor) {
set_dims(meta_tensor.dims());
if (is_runtime_) {
auto* var = PADDLE_GET(Variable*, var_);
if (var == nullptr) return;
if (var->IsType<phi::SelectedRows>()) {
auto* selected_rows = var->GetMutable<phi::SelectedRows>();
auto& input_selected_rows =
......
......@@ -94,17 +94,12 @@ bool InterpretercoreInferShapeContext::HasOutputs(const std::string& name,
if (it == outs.end() || it->second.empty()) {
return false;
}
if (allow_null) {
for (auto& output : it->second) {
if (output != nullptr) return true;
}
return false;
} else {
if (!allow_null) {
for (auto& output : it->second) {
if (output == nullptr) return false;
}
return true;
}
return true;
}
AttrReader InterpretercoreInferShapeContext::Attrs() const {
......
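The four HasOutputs refactors in this commit (here and in the hunks below) also change behavior, not just shape: with allow_null set, a non-empty output slot now counts as present even if every pointer in it is null, matching the new null-tolerant InferMeta path. A standalone sketch of the new predicate:

```cpp
#include <vector>

// New semantics: allow_null => any non-empty slot counts as present;
// otherwise every output pointer must be non-null.
bool HasOutputs(const std::vector<const void*>& outputs, bool allow_null) {
  if (outputs.empty()) return false;
  if (!allow_null) {
    for (const void* out : outputs) {
      if (out == nullptr) return false;
    }
  }
  return true;
}
```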
......@@ -1227,17 +1227,12 @@ bool CompileTimeInferShapeContext::HasOutputs(const std::string &name,
if (output_names.empty()) {
return false;
}
if (allow_null) {
for (auto &output : output_names) {
if (block_.HasVarRecursive(output)) return true;
}
return false;
} else {
if (!allow_null) {
for (auto &output : output_names) {
if (!block_.HasVarRecursive(output)) return false;
}
return true;
}
return true;
}
AttrReader CompileTimeInferShapeContext::Attrs() const {
......
......@@ -755,17 +755,12 @@ class RuntimeInferShapeContext : public InferShapeContext {
if (it == outs.end() || it->second.empty()) {
return false;
}
if (allow_null) {
for (auto& output : it->second) {
if (output != nullptr) return true;
}
return false;
} else {
if (!allow_null) {
for (auto& output : it->second) {
if (output == nullptr) return false;
}
return true;
}
return true;
}
AttrReader Attrs() const override {
......
......@@ -109,21 +109,14 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
if (it == var_map_out_->end() || it->second.empty()) {
return false;
}
if (allow_null) {
for (auto& output : it->second) {
if (output != nullptr) {
return true;
}
}
return false;
} else {
if (!allow_null) {
for (auto& output : it->second) {
if (output == nullptr) {
return false;
}
}
return true;
}
return true;
}
framework::AttrReader Attrs() const override {
......@@ -288,7 +281,11 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
var_map_out_->end(),
platform::errors::NotFound("Can not find [%s] in outputs.", name));
for (auto& var : it->second) {
if (var) {
res.emplace_back(var->MutableVar());
} else {
res.emplace_back(framework::InferShapeVarPtr());
}
}
return res;
}
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
......@@ -115,29 +116,6 @@ class RNNGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "RNN");
OP_INOUT_CHECK(ctx->HasInputs("PreState"), "Input", "PreState", "RNN");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "RNN");
// OP_INOUT_CHECK(ctx->HasInputs("State"), "Input", "State", "RNN");
auto SetOutGradDim = [&ctx](const std::string& name) {
auto g_name = framework::GradVarName(name);
if (ctx->HasOutput(g_name)) {
ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
}
};
SetOutGradDim("Input");
if (ctx->HasOutputs(framework::GradVarName("WeightList"))) {
ctx->SetOutputsDim(framework::GradVarName("WeightList"),
ctx->GetInputsDim("WeightList"));
}
if (ctx->HasOutputs(framework::GradVarName("PreState"))) {
ctx->SetOutputsDim(framework::GradVarName("PreState"),
ctx->GetInputsDim("PreState"));
}
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
......@@ -192,6 +170,9 @@ namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(rnn,
RnnInferShapeFunctor,
PD_INFER_META(phi::RnnInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(rnn_grad,
RnnGradInferShapeFunctor,
PD_INFER_META(phi::RnnGradInferMeta));
REGISTER_OPERATOR(rnn,
ops::RNNOp,
......@@ -199,4 +180,4 @@ REGISTER_OPERATOR(rnn,
ops::RNNGradOpMaker<paddle::framework::OpDesc>,
ops::RNNGradOpMaker<paddle::imperative::OpBase>,
RnnInferShapeFunctor);
REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp);
REGISTER_OPERATOR(rnn_grad, ops::RNNGradOp, RnnGradInferShapeFunctor);
......@@ -49,6 +49,23 @@ class BaseAPI(object):
self.data_transform = self.parse_data_transform(api_item_yaml)
self.inplace_map, self.view_map = {}, {}
self.gene_input_func = {
"const Tensor&": {
"dense": self.gene_dense_input,
"selected_rows": self.gene_selected_rows_input
},
"const paddle::optional<Tensor>&": {
"dense": self.gene_dense_input,
"selected_rows": self.gene_selected_rows_input
},
"const std::vector<Tensor>&": {
"dense": self.gene_vec_dense_input
},
"const paddle::optional<std::vector<Tensor>>&": {
"dense": self.gene_optional_vec_dense_input
}
}
def get_api_name(self, api_item_yaml):
return api_item_yaml['op']
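A rough C++ analogue (hypothetical names; the real gene_input_func table above is Python) of the refactor: input-code generation now dispatches on the API argument type and the PHI tensor kind through a nested table instead of a long if/elif chain, so supporting a new input kind only means adding an entry.

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>

using Gen = std::function<std::string(const std::string&)>;

int main() {
  // Outer key: API-level argument type; inner key: PHI tensor kind.
  std::map<std::string, std::map<std::string, Gen>> gene_input_func = {
      {"const Tensor&",
       {{"dense",
         [](const std::string& n) { return "PrepareData(" + n + ", ...)"; }},
        {"selected_rows",
         [](const std::string& n) { return "TensorToSelectedRows(" + n + ")"; }}}},
  };
  std::cout << gene_input_func["const Tensor&"]["dense"]("x") << "\n";
  return 0;
}
```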
......@@ -550,56 +567,62 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} phi::{infer_meta['func']}({param_code});
"""
def get_kernel_args(self, kernel_tensor_type=None, code_indent=''):
dense_input_trans_map = {
'const Tensor&':
'const phi::DenseTensor&',
'const std::vector<Tensor>&':
'const std::vector<const phi::DenseTensor*>&',
'const paddle::optional<Tensor&>':
'paddle::optional<const phi::DenseTensor&>',
'const paddle::optional<Tensor>&':
'const paddle::optional<phi::DenseTensor>&',
'const paddle::optional<std::vector<Tensor>>&':
'const paddle::optional<std::vector<const phi::DenseTensor*>>&'
}
dense_out_trans_map = {
'Tensor': 'phi::DenseTensor*',
'std::vector<Tensor>': 'std::vector<phi::DenseTensor*>&'
}
sr_input_trans_map = {
'const Tensor&':
'const phi::SelectedRows&',
'const paddle::optional<Tensor>&':
'const paddle::optional<phi::SelectedRows>&'
}
sr_out_trans_map = {'Tensor': 'phi::SelectedRows*'}
def gene_trans_flag(self, input_name):
trans_flag = "{}"
if input_name in self.data_transform['skip_transform']:
trans_flag = "{true}"
elif input_name in self.data_transform['support_trans_dtype']:
trans_flag = "{false, true}"
return trans_flag
def gene_dense_input(self,
input_name,
input_name_tensor_map,
code_indent=''):
input_tensor_code = ""
trans_flag = self.gene_trans_flag(input_name)
input_names = self.inputs['names']
input_infos = self.inputs['input_info']
kernel_args_type_list = ['const platform::DeviceContext&']
attr_names = self.attrs['names']
kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
input_name_tensor_map[input_name].append(
(f"{PREFIX_TENSOR_NAME}{input_name}", False))
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});"""
return input_tensor_code
def gene_selected_rows_input(self,
input_name,
input_name_tensor_map,
code_indent=''):
input_tensor_code = ""
trans_flag = self.gene_trans_flag(input_name)
input_names = self.inputs['names']
attr_names = self.attrs['names']
kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
input_name_tensor_map[input_name].append(
(f"{PREFIX_TENSOR_NAME}{input_name}", False))
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});
"""
return input_tensor_code
def gene_optional_vec_dense_input(self,
input_name,
input_name_tensor_map,
code_indent=''):
input_tensor_code = ""
input_name_tensor_map = collections.defaultdict(list)
for i, input_name in enumerate(input_names):
# set input code
if input_name in kernel_param:
# input is dense tensor
if kernel_tensor_type is None or kernel_tensor_type[0][
kernel_param.index(input_name)] == 'dense':
trans_flag = "{}"
if input_name in self.data_transform['skip_transform']:
trans_flag = "{true}"
elif input_name in self.data_transform[
'support_trans_dtype']:
trans_flag = "{false, true}"
if input_name in self.optional_vars:
if self.inputs['input_info'][
input_name] == "const paddle::optional<std::vector<Tensor>>&":
trans_flag = self.gene_trans_flag(input_name)
input_names = self.inputs['names']
attr_names = self.attrs['names']
kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
if input_name in self.inplace_map.values():
input_name_tensor_map[input_name].append(
(f"{PREFIX_TENSOR_NAME}{input_name}", True))
......@@ -607,8 +630,7 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} paddle::optional<std::vector<const phi::DenseTensor*>> {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});"""
else:
input_name_tensor_map[input_name].append(
(f"{PREFIX_TENSOR_NAME}{input_name}_vec",
True))
(f"{PREFIX_TENSOR_NAME}{input_name}_vec", True))
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});
{code_indent} paddle::optional<std::vector<const phi::DenseTensor*>> {PREFIX_TENSOR_NAME}{input_name};
......@@ -618,22 +640,20 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} {PREFIX_TENSOR_NAME}{input_name}->at(i) = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
{code_indent} }}
{code_indent} }}"""
else:
input_name_tensor_map[input_name].append(
(f"{PREFIX_TENSOR_NAME}{input_name}", False))
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});"""
return input_tensor_code
else:
if self.inputs['input_info'][
input_name] == "const Tensor&":
input_name_tensor_map[input_name].append(
(f"{PREFIX_TENSOR_NAME}{input_name}", False))
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});"""
def gene_vec_dense_input(self,
input_name,
input_name_tensor_map,
code_indent=''):
input_tensor_code = ""
trans_flag = self.gene_trans_flag(input_name)
input_names = self.inputs['names']
attr_names = self.attrs['names']
kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
elif self.inputs['input_info'][
input_name] == "const std::vector<Tensor>&":
if input_name in self.inplace_map.values():
input_name_tensor_map[input_name].append(
(f"{PREFIX_TENSOR_NAME}{input_name}", True))
......@@ -641,24 +661,37 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
{code_indent} std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name} = TensorToConstDenseTensorPtr({input_name});"""
else:
input_name_tensor_map[input_name].append(
(f"{PREFIX_TENSOR_NAME}{input_name}_vec",
True))
(f"{PREFIX_TENSOR_NAME}{input_name}_vec", True))
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name}_vec = PrepareData({input_name}, kernel.InputAt({kernel_param.index(input_name)}), {trans_flag});
{code_indent} std::vector<const phi::DenseTensor*> {PREFIX_TENSOR_NAME}{input_name}({PREFIX_TENSOR_NAME}{input_name}_vec->size());
{code_indent} for (size_t i = 0; i < {PREFIX_TENSOR_NAME}{input_name}.size(); ++i) {{
{code_indent} {PREFIX_TENSOR_NAME}{input_name}[i] = &{PREFIX_TENSOR_NAME}{input_name}_vec->at(i);
{code_indent} }}"""
return input_tensor_code
def gene_input(self, kernel_tensor_type=None, code_indent=''):
input_names = self.inputs['names']
attr_names = self.attrs['names']
kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
input_name_tensor_map = collections.defaultdict(list)
input_tensor_code = ""
for i, input_name in enumerate(input_names):
# set input code
if input_name in kernel_param:
# input is dense tensor
api_tensor_type = self.inputs['input_info'][input_name]
phi_tensor_type = 'dense' if kernel_tensor_type is None else kernel_tensor_type[
0][kernel_param.index(input_name)]
if api_tensor_type in self.gene_input_func.keys():
input_tensor_code += self.gene_input_func[api_tensor_type][
phi_tensor_type](input_name, input_name_tensor_map,
code_indent)
else:
# do nothing
pass
else: # input is selected_rows
input_name_tensor_map[input_name].append(
(f"{PREFIX_TENSOR_NAME}{input_name}", False))
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = TensorToSelectedRows({input_name});
"""
else:
if input_name in self.infer_meta['param']:
if input_name in self.optional_vars:
......@@ -674,6 +707,45 @@ PADDLE_API {self.get_return_type(inplace_flag=True)} {api_func_name}({self.get_d
else:
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = {input_name}.impl();"""
return input_name_tensor_map, input_tensor_code
def get_kernel_args(self, kernel_tensor_type=None, code_indent=''):
dense_input_trans_map = {
'const Tensor&':
'const phi::DenseTensor&',
'const std::vector<Tensor>&':
'const std::vector<const phi::DenseTensor*>&',
'const paddle::optional<Tensor&>':
'paddle::optional<const phi::DenseTensor&>',
'const paddle::optional<Tensor>&':
'const paddle::optional<phi::DenseTensor>&',
'const paddle::optional<std::vector<Tensor>>&':
'const paddle::optional<std::vector<const phi::DenseTensor*>>&'
}
dense_out_trans_map = {
'Tensor': 'phi::DenseTensor*',
'std::vector<Tensor>': 'std::vector<phi::DenseTensor*>&'
}
sr_input_trans_map = {
'const Tensor&':
'const phi::SelectedRows&',
'const paddle::optional<Tensor>&':
'const paddle::optional<phi::SelectedRows>&'
}
sr_out_trans_map = {'Tensor': 'phi::SelectedRows*'}
input_names = self.inputs['names']
input_infos = self.inputs['input_info']
kernel_args_type_list = ['const platform::DeviceContext&']
attr_names = self.attrs['names']
kernel_param = self.kernel['param']
if kernel_param is None:
kernel_param = input_names + attr_names
input_name_tensor_map, input_tensor_code = self.gene_input(
kernel_tensor_type, code_indent)
input_tensor_code = input_tensor_code + f"""
{code_indent} if(platform::RecordOpInfoSupplement::IsEnabled()){{"""
single_tensor_names = []
......
......@@ -45,6 +45,26 @@ class ForwardAPI(BaseAPI):
else:
return self.api
def gene_input(self, kernel_tensor_type=None, code_indent=''):
kernel_param = self.kernel['param']
input_name_tensor_map, input_tensor_code = super().gene_input(
kernel_tensor_type, code_indent)
# generate the input that is in view list
for i, input_name in enumerate(self.inputs['names']):
if input_name in self.view_map.values(
) and input_name not in input_name_tensor_map.keys():
if kernel_tensor_type is None or kernel_tensor_type[0][
kernel_param.index(input_name)] == 'dense':
trans_flag = self.gene_trans_flag(input_name)
input_tensor_code = input_tensor_code + f"""
{code_indent} auto {PREFIX_TENSOR_NAME}{input_name} = PrepareData({input_name}, kernel.InputAt(0), {trans_flag});"""
else:
# do nothing
pass
return input_name_tensor_map, input_tensor_code
def parse_intermediate(self, api_item_yaml):
if 'intermediate' in api_item_yaml:
intermediate_outs = [
......@@ -215,11 +235,15 @@ class ForwardAPI(BaseAPI):
if not inplace_flag and self.view_map is not None and self.outputs[
'names'][i] in self.view_map:
if out_dtype_list[i] == 'Tensor':
output_create = output_create + f"""
{code_indent} kernel_out_{i}->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
{code_indent} kernel_out_{i}->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
{code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
{code_indent} kernel_out_{i}->ShareBufferWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
{code_indent} kernel_out_{i}->ShareInplaceVersionCounterWith(*{PREFIX_TENSOR_NAME}{self.view_map[self.outputs['names'][i]]});
{code_indent} VLOG(3) << "Perform View between Output and Input Tensor, share allocation and inplace version.";"""
else:
raise ValueError(
"{} : Output error: only support Tensor type when use view in yaml. But get {}"
.format(self.api, out_dtype_list[i]))
else:
raise ValueError(
"{} : Output error: the output should not be empty.".format(
......
......@@ -1928,6 +1928,18 @@
output : Tensor(x_grad)
invoke : reverse(out_grad, axis)
- backward_op : rnn_grad
forward : rnn (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor dropout_state_in, float dropout_prob, bool is_bidirec, int input_size, int hidden_size, int num_layers, str mode, int seed, bool is_test) -> Tensor(out), Tensor(dropout_state_out), Tensor[](state), Tensor(reserve)
args : (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor out, Tensor dropout_state_out, Tensor reserve, Tensor out_grad, Tensor[] state_grad, float dropout_prob, bool is_bidirec, int input_size, int hidden_size, int num_layers, str mode, int seed, bool is_test)
output : Tensor(x_grad), Tensor[](pre_state_grad){pre_state.size()}, Tensor[](weight_list_grad){weight_list.size()}
infer_meta :
func : RnnGradInferMeta
param : [x, pre_state, weight_list]
kernel :
func : rnn_grad
data_type: out_grad
optional : sequence_length
- backward_op : roi_align_grad
forward : roi_align (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned) -> Tensor(out)
args : (Tensor x, Tensor boxes, Tensor boxes_num, Tensor out_grad, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, bool aligned)
......
......@@ -2991,6 +2991,21 @@
func: overlap_add
backward: overlap_add_grad
- op: rnn
args: (Tensor x, Tensor[] pre_state, Tensor[] weight_list, Tensor sequence_length, Tensor dropout_state_in, float dropout_prob=0.0, bool is_bidirec=false, int input_size=10, int hidden_size=100, int num_layers=1, str mode="RNN_TANH", int seed=0, bool is_test=false)
output: Tensor(out), Tensor(dropout_state_out), Tensor[](state){pre_state.size()}, Tensor(reserve)
infer_meta:
func: RnnInferMeta
param : [x, pre_state, weight_list, sequence_length, dropout_prob, is_bidirec, input_size, hidden_size, num_layers, mode, seed, is_test]
kernel:
func: rnn
param : [x, pre_state, weight_list, sequence_length, dropout_prob, is_bidirec, input_size, hidden_size, num_layers, mode, seed, is_test]
data_type: x
backward: rnn_grad
optional : sequence_length
intermediate : reserve
view : (dropout_state_in -> dropout_state_out)
- op: uniform_random_inplace
args: (Tensor x, float min, float max, int seed, int diag_num, int diag_step, float diag_val)
output: Tensor(out)
......
......@@ -807,6 +807,33 @@ void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad,
}
}
void RnnGradInferMeta(const MetaTensor& x,
const std::vector<const MetaTensor*>& pre_state,
const std::vector<const MetaTensor*>& weight_list,
MetaTensor* x_grad,
std::vector<MetaTensor*> pre_state_grad,
std::vector<MetaTensor*> weight_grad_list) {
PADDLE_ENFORCE_GT(
pre_state.size(),
0UL,
phi::errors::InvalidArgument(
"The input pre_state in RnnGradInferMeta can't be empty."));
PADDLE_ENFORCE_GT(
weight_grad_list.size(),
0UL,
phi::errors::InvalidArgument(
"The input weight_grad_list in RnnGradInferMeta can't be empty."));
if (x_grad) {
UnchangedInferMeta(x, x_grad);
}
if (pre_state_grad.size()) {
UnchangedMultiInferMeta(pre_state, pre_state_grad);
}
if (weight_grad_list.size()) {
UnchangedMultiInferMeta(weight_list, weight_grad_list);
}
}
void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates,
const MetaTensor& out_grad,
......
......@@ -339,6 +339,13 @@ void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad,
const MetaTensor& x_grad_grad,
MetaTensor* out_grad_grad);
void RnnGradInferMeta(const MetaTensor& x,
const std::vector<const MetaTensor*>& pre_state,
const std::vector<const MetaTensor*>& weight_list,
MetaTensor* x_grad,
std::vector<MetaTensor*> pre_state_grad,
std::vector<MetaTensor*> weight_grad_list);
void ScatterGradInferMeta(const MetaTensor& index,
const MetaTensor& updates,
const MetaTensor& out_grad,
......
......@@ -2496,6 +2496,14 @@ void StackInferMeta(const std::vector<const MetaTensor*>& x,
void UnchangedMultiInferMeta(const std::vector<const MetaTensor*>& x,
std::vector<MetaTensor*> out) {
PADDLE_ENFORCE_EQ(
x.size(),
out.size(),
phi::errors::InvalidArgument(
"Input's size should be equal to the output's size"
"but received input size: (%d) does not equals output_size: (%d)",
x.size(),
out.size()));
for (size_t i = 0; i < x.size(); ++i) {
if (out[i]) {
out[i]->share_meta(*x[i]);
......
......@@ -142,6 +142,30 @@ class TestRNNOp(OpTest):
self.check_grad(set(grad_check_list),
['Out', 'last_hidden', 'last_cell'])
def test_grad_only_input(self):
if not self.is_test:
var_name_list = self.get_weight_names()
grad_check_list = ['Input']
grad_check_list.extend(var_name_list)
self.check_grad(set(grad_check_list),
['Out', 'last_hidden', 'last_cell'])
def test_grad_only_h(self):
if not self.is_test:
var_name_list = self.get_weight_names()
grad_check_list = ['init_h']
grad_check_list.extend(var_name_list)
self.check_grad(set(grad_check_list),
['Out', 'last_hidden', 'last_cell'])
def test_grad_only_c(self):
if not self.is_test:
var_name_list = self.get_weight_names()
grad_check_list = ['init_c']
grad_check_list.extend(var_name_list)
self.check_grad(set(grad_check_list),
['Out', 'last_hidden', 'last_cell'])
class TestRNNOp1(TestRNNOp):
......
......@@ -1012,7 +1012,13 @@ class RNNBase(LayerList):
if not self.time_major:
inputs = paddle.tensor.transpose(inputs, [1, 0, 2])
if in_dynamic_mode():
if in_dygraph_mode():
out, _, state = _C_ops.rnn(
inputs, initial_states, self._all_weights, sequence_length,
self._dropout_state, self.dropout, self.num_directions == 2,
self.input_size, self.hidden_size, self.num_layers, self.mode,
0, not self.training)
elif in_dynamic_mode():
_, _, out, state = _legacy_C_ops.rnn(
inputs, initial_states, self._all_weights, sequence_length,
self._dropout_state, self.state_components, 'dropout_prob',
......