未验证 提交 3a59ede9 编写于 作者: C Chenxiao Niu 提交者: GitHub

[MLU] add rnn backward kernel. (#43969)

上级 9f397f16
...@@ -4616,6 +4616,88 @@ MLURNNDesc::~MLURNNDesc() { ...@@ -4616,6 +4616,88 @@ MLURNNDesc::~MLURNNDesc() {
reservespace_size)); reservespace_size));
} }
/* static */ void MLUCnnl::RNNBackward(const ExecutionContext& ctx,
const cnnlRNNDescriptor_t rnn_desc,
cnnlWgradMode_t add_grad,
const int dev_seq_lengths[],
const void* weight_param_ptr,
void* dweight_param_ptr,
size_t weightspace_size,
const cnnlSeqDataDescriptor_t x_desc,
const void* x,
void* dx,
const cnnlSeqDataDescriptor_t y_desc,
const void* y,
const void* dy,
const cnnlTensorDescriptor_t hx_desc,
const void* hx,
const void* dhy,
void* dhx,
const cnnlTensorDescriptor_t cx_desc,
const void* cx,
const void* dcy,
void* dcx,
void* reservespace_ptr,
size_t reservespace_size) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_NOT_NULL(
rnn_desc,
paddle::platform::errors::Fatal(
"MLU RNNForward failed. rnn_desc initializing failed."));
PADDLE_ENFORCE_NOT_NULL(
x_desc,
paddle::platform::errors::Fatal(
"MLU RNNForward failed. x_desc initializing failed."));
auto& dev_ctx = GetDevCtxFromCTX(ctx);
size_t workspace_size;
Tensor workspace;
PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetRNNTempSizes(
handle, rnn_desc, x_desc, &workspace_size, &reservespace_size));
workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(workspace_size)}, dev_ctx);
void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
PADDLE_ENFORCE_MLU_SUCCESS(cnnlRNNBackwardData(handle,
rnn_desc,
dev_seq_lengths,
y_desc,
y,
dy,
x_desc,
dx,
hx_desc,
hx,
dhy,
dhx,
cx_desc,
cx,
dcy,
dcx,
weight_param_ptr,
weightspace_size,
workspace_ptr,
workspace_size,
reservespace_ptr,
reservespace_size));
PADDLE_ENFORCE_MLU_SUCCESS(cnnlRNNBackwardWeights(handle,
rnn_desc,
add_grad,
dev_seq_lengths,
x_desc,
x,
hx_desc,
hx,
y_desc,
y,
dweight_param_ptr,
weightspace_size,
workspace_ptr,
workspace_size,
reservespace_ptr,
reservespace_size));
}
/* static */ void MLUCnnl::Mask(const ExecutionContext& ctx, /* static */ void MLUCnnl::Mask(const ExecutionContext& ctx,
cnnlMaskedOp_t masked_mode, cnnlMaskedOp_t masked_mode,
const cnnlTensorDescriptor_t input_desc, const cnnlTensorDescriptor_t input_desc,
......
...@@ -1924,6 +1924,30 @@ class MLUCnnl { ...@@ -1924,6 +1924,30 @@ class MLUCnnl {
void* cy, void* cy,
void* reservespace_ptr); void* reservespace_ptr);
static void RNNBackward(const ExecutionContext& ctx,
const cnnlRNNDescriptor_t rnn_desc,
cnnlWgradMode_t add_grad,
const int dev_seq_lengths[],
const void* weight_param_ptr,
void* dweight_param_ptr,
size_t weightspace_size,
const cnnlSeqDataDescriptor_t x_desc,
const void* x,
void* dx,
const cnnlSeqDataDescriptor_t y_desc,
const void* y,
const void* dy,
const cnnlTensorDescriptor_t hx_desc,
const void* hx,
const void* dhy,
void* dhx,
const cnnlTensorDescriptor_t cx_desc,
const void* cx,
const void* dcy,
void* dcx,
void* reservespace_ptr,
size_t reservespace_size);
static void Mask(const ExecutionContext& ctx, static void Mask(const ExecutionContext& ctx,
cnnlMaskedOp_t masked_mode, cnnlMaskedOp_t masked_mode,
const cnnlTensorDescriptor_t input_desc, const cnnlTensorDescriptor_t input_desc,
......
...@@ -28,7 +28,7 @@ void reset_parameter_vector( ...@@ -28,7 +28,7 @@ void reset_parameter_vector(
const std::vector<TensorType>& raw_params_vec, const std::vector<TensorType>& raw_params_vec,
const int& num_layers, const int& num_layers,
const bool& is_bidirec, const bool& is_bidirec,
std::vector<std::vector<std::pair<const T*, size_t>>>* params_vec) { std::vector<std::vector<std::pair<T*, size_t>>>* params_vec) {
// the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers // the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers
// + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to // + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to
// ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers // ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers
...@@ -47,7 +47,8 @@ void reset_parameter_vector( ...@@ -47,7 +47,8 @@ void reset_parameter_vector(
} }
using remove_cv_t = typename std::remove_cv<T>::type; using remove_cv_t = typename std::remove_cv<T>::type;
params_vec->at(i)[j] = std::make_pair( params_vec->at(i)[j] = std::make_pair(
raw_params_vec[tensor_idx]->template data<remove_cv_t>(), const_cast<T*>(
raw_params_vec[tensor_idx]->template data<remove_cv_t>()),
raw_params_vec[tensor_idx]->numel() * sizeof(T)); raw_params_vec[tensor_idx]->numel() * sizeof(T));
} }
} }
...@@ -66,7 +67,6 @@ class RNNMLUKernel : public framework::OpKernel<T> { ...@@ -66,7 +67,6 @@ class RNNMLUKernel : public framework::OpKernel<T> {
// Output // Output
auto state = ctx.MultiOutput<Tensor>("State"); auto state = ctx.MultiOutput<Tensor>("State");
auto* output = ctx.Output<Tensor>("Out"); auto* output = ctx.Output<Tensor>("Out");
// auto* dropout_mask = ctx.Output<Tensor>("DropoutState");
auto* reserve_data = ctx.Output<Tensor>("Reserve"); auto* reserve_data = ctx.Output<Tensor>("Reserve");
// Attributes // Attributes
const int& num_layers = ctx.Attr<int>("num_layers"); const int& num_layers = ctx.Attr<int>("num_layers");
...@@ -79,14 +79,6 @@ class RNNMLUKernel : public framework::OpKernel<T> { ...@@ -79,14 +79,6 @@ class RNNMLUKernel : public framework::OpKernel<T> {
sequence_length = ctx.Input<Tensor>("SequenceLength"); sequence_length = ctx.Input<Tensor>("SequenceLength");
} }
// if (dropout_mask->IsInitialized()) {
// if (dropout_mask->numel() != output->numel()) dropout_mask->clear();
// }
// dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());
// auto& dev_ctx = ctx.template device_context<DeviceContext>();
// phi::funcs::SetConstant<platform::XPUDeviceContext, uint8_t> ones;
// ones(dev_ctx, dropout_mask, static_cast<uint8_t>(1));
auto init_h = pre_state[0]; // -> hx auto init_h = pre_state[0]; // -> hx
auto init_c = pre_state[1]; // -> cx auto init_c = pre_state[1]; // -> cx
auto last_h = state[0]; auto last_h = state[0];
...@@ -143,7 +135,7 @@ class RNNMLUKernel : public framework::OpKernel<T> { ...@@ -143,7 +135,7 @@ class RNNMLUKernel : public framework::OpKernel<T> {
init_c->dims()[0])); init_c->dims()[0]));
// weightlist // weightlist
std::vector<std::vector<std::pair<const T*, size_t>>> parameter_lists; std::vector<std::vector<std::pair<T*, size_t>>> parameter_lists;
parameter_lists.resize(num_layers); parameter_lists.resize(num_layers);
reset_parameter_vector( reset_parameter_vector(
weight_list, num_layers, is_bidirec, &parameter_lists); weight_list, num_layers, is_bidirec, &parameter_lists);
...@@ -363,9 +355,390 @@ class RNNMLUKernel : public framework::OpKernel<T> { ...@@ -363,9 +355,390 @@ class RNNMLUKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T>
class RNNMLUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto stream = ctx.template device_context<MLUDeviceContext>().stream();
// get the tensor pointer for the input
auto* input = ctx.Input<Tensor>("Input");
auto pre_state = ctx.MultiInput<Tensor>("PreState");
auto weight_list = ctx.MultiInput<framework::Tensor>("WeightList");
auto* output = ctx.Input<Tensor>("Out");
auto* reserve_data = ctx.Input<Tensor>("Reserve");
const int& num_layers = ctx.Attr<int>("num_layers");
const bool& is_bidirec = ctx.Attr<bool>("is_bidirec");
const int& hidden_size = ctx.Attr<int>("hidden_size");
const std::string& mode = ctx.Attr<std::string>("mode");
bool has_seq_length = ctx.HasInput("SequenceLength");
const Tensor* sequence_length = nullptr;
if (has_seq_length) {
sequence_length = ctx.Input<Tensor>("SequenceLength");
}
PADDLE_ENFORCE_EQ(
mode,
"LSTM",
platform::errors::InvalidArgument(
"XPU only support LSTM mode now, current mode is %s", mode));
auto init_h = pre_state[0]; // -> hx
auto init_c = pre_state[1]; // -> cx
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto state_grad = ctx.MultiInput<Tensor>(framework::GradVarName("State"));
auto last_h_grad = state_grad[0]; // -> dhy
auto last_c_grad = state_grad[1]; // -> dcy
// get the tensor pointer for the output
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
auto weight_grad_list = ctx.MultiOutput<framework::Tensor>(
framework::GradVarName("WeightList"));
auto pre_state_grad =
ctx.MultiOutput<Tensor>(framework::GradVarName("PreState"));
Tensor* init_h_grad = nullptr;
Tensor* init_c_grad = nullptr;
if (pre_state_grad.size() > 0) { // has gradient
init_h_grad = pre_state_grad[0]; // -> dhx
init_c_grad = pre_state_grad[1]; // -> dcx
}
// check shape
const int in_out_dim_num = input->dims().size();
const int& seq_len = input->dims()[0];
const int& batch_size = input->dims()[1];
const int& input_dim = input->dims()[2];
const int& direction_num = is_bidirec ? 2 : 1;
int in_dim_arr[in_out_dim_num] = {seq_len, batch_size, input_dim};
int out_dim_arr[in_out_dim_num] = {
seq_len, batch_size, direction_num * hidden_size};
int proj_size = hidden_size;
PADDLE_ENFORCE_EQ(
num_layers,
1,
platform::errors::InvalidArgument(
"MLU only support 1 num_layers, current num_layers is %s",
num_layers));
PADDLE_ENFORCE_EQ(
init_h->dims()[0],
num_layers * direction_num,
platform::errors::InvalidArgument("The num_layers of in RNN layer must"
" be the same as first dim of init"
"hidden, but received num_layers:%d,"
" dim:%d",
num_layers,
init_h->dims()[0]));
PADDLE_ENFORCE_EQ(
init_c->dims()[0],
num_layers * direction_num,
platform::errors::InvalidArgument(
"The num_layers of in RNN layer must"
" be the same as first dim of cell state hidden, but received"
" num_layers:%d, dim:%d",
num_layers,
init_c->dims()[0]));
std::vector<std::vector<std::pair<T*, size_t>>> parameter_lists;
parameter_lists.resize(num_layers);
reset_parameter_vector(
weight_list, num_layers, is_bidirec, &parameter_lists);
for (unsigned int i = 0; i < weight_grad_list.size(); ++i) {
weight_grad_list[i]->mutable_data<T>(ctx.GetPlace());
}
std::vector<std::vector<std::pair<T*, size_t>>> parameter_lists_grad;
parameter_lists_grad.resize(num_layers);
reset_parameter_vector(
weight_grad_list, num_layers, is_bidirec, &parameter_lists_grad);
// allocate the memory and initization the input_grad
input_grad->mutable_data<T>(input->dims(), ctx.GetPlace());
FillMLUTensorWithHostValue(ctx, static_cast<T>(0.0), input_grad);
Tensor a, b;
Tensor* dynamic_grad_pre_h = &a;
Tensor* dynamic_grad_pre_c = &b;
if (init_h_grad) {
init_h_grad->mutable_data<T>(last_h_grad->dims(), ctx.GetPlace());
FillMLUTensorWithHostValue(ctx, static_cast<T>(0.0), init_h_grad);
} else {
dynamic_grad_pre_h->Resize(last_h_grad->dims());
dynamic_grad_pre_h->mutable_data<T>(ctx.GetPlace());
FillMLUTensorWithHostValue(ctx, static_cast<T>(0.0), dynamic_grad_pre_h);
init_h_grad = dynamic_grad_pre_h;
}
if (init_c_grad) {
init_c_grad->mutable_data<T>(last_c_grad->dims(), ctx.GetPlace());
} else {
dynamic_grad_pre_c->Resize(last_h_grad->dims());
dynamic_grad_pre_c->mutable_data<T>(ctx.GetPlace());
init_c_grad = dynamic_grad_pre_c;
}
std::vector<int> seq_len_vec(batch_size, seq_len);
if (has_seq_length) {
seq_len_vec = operators::GetDataFromTensor(sequence_length);
}
cnnlDirectionMode_t direction =
is_bidirec ? CNNL_RNN_BIDIRECTIONAL : CNNL_RNN_UNIDIRECTIONAL;
MLUSeqDataDesc input_seq_data_desc(CNNL_SEQDATA_TNC,
ToCnnlDataType(input->dtype()),
in_out_dim_num,
in_dim_arr,
static_cast<int>(seq_len_vec.size()),
seq_len_vec.data(),
nullptr);
MLUSeqDataDesc out_seq_data_desc(CNNL_SEQDATA_TNC,
ToCnnlDataType(input->dtype()),
in_out_dim_num,
out_dim_arr,
static_cast<int>(seq_len_vec.size()),
seq_len_vec.data(),
nullptr);
MLUCnnlTensorDesc hx_desc(*init_h);
MLUCnnlTensorDesc cx_desc(*init_c);
MLURNNDesc rnn_desc(CNNL_LSTM,
CNNL_RNN_DOUBLE_BIAS,
direction,
CNNL_RNN_LINEAR_INPUT,
ToCnnlDataType(input->dtype()),
ToCnnlDataType(input->dtype()),
input_dim,
hidden_size,
/*projection*/ proj_size,
num_layers,
nullptr,
CNNL_RNN_PADDED_IO_DISABLED);
rnn_desc.SetRNNMaskMode(CNNL_LSTM_MASK_ENABLED);
// copy weight
size_t weightspace_size;
framework::Tensor weightspace, dweightspace;
PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetRNNWeightSpaceSize(
GetHandleFromCTX(ctx), rnn_desc.get(), &weightspace_size));
weightspace = ctx.AllocateTmpTensor<T, DeviceContext>(
{static_cast<int64_t>(weightspace_size)}, dev_ctx);
dweightspace = ctx.AllocateTmpTensor<T, DeviceContext>(
{static_cast<int64_t>(weightspace_size)}, dev_ctx);
void* weightspace_ptr = weightspace.mutable_data(ctx.GetPlace());
auto w_x = parameter_lists[0][0];
auto w_h = parameter_lists[0][1];
auto b_x = parameter_lists[0][2];
auto b_h = parameter_lists[0][3];
auto actual_total_w_size =
w_x.second + w_h.second + b_x.second + b_h.second;
void* w_x_ptr = weightspace_ptr;
void* w_h_ptr = static_cast<char*>(weightspace_ptr) + w_x.second;
void* b_x_ptr =
static_cast<char*>(weightspace_ptr) + w_x.second + w_h.second;
void* b_h_ptr = static_cast<char*>(weightspace_ptr) + w_x.second +
w_h.second + b_x.second;
memory::Copy(weightspace.place(),
w_x_ptr,
weightspace.place(),
w_x.first,
w_x.second,
stream);
memory::Copy(weightspace.place(),
w_h_ptr,
weightspace.place(),
w_h.first,
w_h.second,
stream);
memory::Copy(weightspace.place(),
b_x_ptr,
weightspace.place(),
b_x.first,
b_x.second,
stream);
memory::Copy(weightspace.place(),
b_h_ptr,
weightspace.place(),
b_h.first,
b_h.second,
stream);
if (is_bidirec) {
auto bw_x = parameter_lists[0][4];
auto bw_h = parameter_lists[0][5];
auto bb_x = parameter_lists[0][6];
auto bb_h = parameter_lists[0][7];
void* bw_x_ptr =
static_cast<char*>(weightspace_ptr) + actual_total_w_size;
void* bw_h_ptr = static_cast<char*>(weightspace_ptr) +
actual_total_w_size + bw_x.second;
void* bb_x_ptr = static_cast<char*>(weightspace_ptr) +
actual_total_w_size + bw_x.second + bw_h.second;
void* bb_h_ptr = static_cast<char*>(weightspace_ptr) +
actual_total_w_size + bw_x.second + bw_h.second +
bb_x.second;
actual_total_w_size +=
bw_x.second + bw_h.second + bb_x.second + bb_h.second;
memory::Copy(weightspace.place(),
bw_x_ptr,
weightspace.place(),
bw_x.first,
bw_x.second,
stream);
memory::Copy(weightspace.place(),
bw_h_ptr,
weightspace.place(),
bw_h.first,
bw_h.second,
stream);
memory::Copy(weightspace.place(),
bb_x_ptr,
weightspace.place(),
bb_x.first,
bb_x.second,
stream);
memory::Copy(weightspace.place(),
bb_h_ptr,
weightspace.place(),
bb_h.first,
bb_h.second,
stream);
}
dev_ctx.Wait();
PADDLE_ENFORCE_EQ(weightspace_size,
actual_total_w_size,
platform::errors::InvalidArgument(
"The weightsize doesn't match"
" weightspace_size:%d, actual_total_w_size:%d",
weightspace_size,
actual_total_w_size));
MLUCnnl::RNNBackward(ctx,
rnn_desc.get(),
CNNL_WGRAD_MODE_SET,
seq_len_vec.data(),
GetBasePtr(&weightspace),
GetBasePtr(&dweightspace),
weightspace.numel() * sizeof(T),
input_seq_data_desc.get(),
GetBasePtr(input),
GetBasePtr(input_grad),
out_seq_data_desc.get(),
GetBasePtr(output),
GetBasePtr(output_grad),
hx_desc.get(),
GetBasePtr(init_h),
GetBasePtr(last_h_grad),
GetBasePtr(init_h_grad),
cx_desc.get(),
GetBasePtr(init_c),
GetBasePtr(last_c_grad),
GetBasePtr(init_c_grad),
const_cast<void*>(GetBasePtr(reserve_data)),
reserve_data->numel() * sizeof(T));
void* dweightspace_ptr = dweightspace.mutable_data(ctx.GetPlace());
auto dw_x = parameter_lists_grad[0][0];
auto dw_h = parameter_lists_grad[0][1];
auto db_x = parameter_lists_grad[0][2];
auto db_h = parameter_lists_grad[0][3];
auto dactual_total_w_size =
dw_x.second + dw_h.second + db_x.second + db_h.second;
void* dw_x_ptr = dweightspace_ptr;
void* dw_h_ptr = static_cast<char*>(dweightspace_ptr) + dw_x.second;
void* db_x_ptr =
static_cast<char*>(dweightspace_ptr) + dw_x.second + dw_h.second;
void* db_h_ptr = static_cast<char*>(dweightspace_ptr) + dw_x.second +
dw_h.second + db_x.second;
memory::Copy(weightspace.place(),
dw_x.first,
weightspace.place(),
dw_x_ptr,
dw_x.second,
stream);
memory::Copy(weightspace.place(),
dw_h.first,
weightspace.place(),
dw_h_ptr,
dw_h.second,
stream);
memory::Copy(weightspace.place(),
db_x.first,
weightspace.place(),
db_x_ptr,
db_x.second,
stream);
memory::Copy(weightspace.place(),
db_h.first,
weightspace.place(),
db_h_ptr,
db_h.second,
stream);
if (is_bidirec) {
auto dbw_x = parameter_lists_grad[0][4];
auto dbw_h = parameter_lists_grad[0][5];
auto dbb_x = parameter_lists_grad[0][6];
auto dbb_h = parameter_lists_grad[0][7];
void* dbw_x_ptr =
static_cast<char*>(dweightspace_ptr) + dactual_total_w_size;
void* dbw_h_ptr = static_cast<char*>(dweightspace_ptr) +
dactual_total_w_size + dbw_x.second;
void* dbb_x_ptr = static_cast<char*>(dweightspace_ptr) +
dactual_total_w_size + dbw_x.second + dbw_h.second;
void* dbb_h_ptr = static_cast<char*>(dweightspace_ptr) +
dactual_total_w_size + dbw_x.second + dbw_h.second +
dbb_x.second;
dactual_total_w_size +=
dbw_x.second + dbw_h.second + dbb_x.second + dbb_h.second;
memory::Copy(weightspace.place(),
dbw_x.first,
weightspace.place(),
dbw_x_ptr,
dbw_x.second,
stream);
memory::Copy(weightspace.place(),
dbw_h.first,
weightspace.place(),
dbw_h_ptr,
dbw_h.second,
stream);
memory::Copy(weightspace.place(),
dbb_x.first,
weightspace.place(),
dbb_x_ptr,
dbb_x.second,
stream);
memory::Copy(weightspace.place(),
dbb_h.first,
weightspace.place(),
dbb_h_ptr,
dbb_h.second,
stream);
}
dev_ctx.Wait();
PADDLE_ENFORCE_EQ(weightspace_size,
dactual_total_w_size,
platform::errors::InvalidArgument(
"The weightsize doesn't match"
" weightspace_size:%d, dactual_total_w_size:%d",
weightspace_size,
dactual_total_w_size));
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_MLU_KERNEL( REGISTER_OP_MLU_KERNEL(
rnn, ops::RNNMLUKernel<paddle::platform::MLUDeviceContext, float>); rnn, ops::RNNMLUKernel<paddle::platform::MLUDeviceContext, float>);
REGISTER_OP_MLU_KERNEL(
rnn_grad, ops::RNNMLUGradKernel<paddle::platform::MLUDeviceContext, float>);
...@@ -135,43 +135,50 @@ class TestRNNOp(OpTest): ...@@ -135,43 +135,50 @@ class TestRNNOp(OpTest):
def test_output(self): def test_output(self):
self.check_output_with_place( self.check_output_with_place(
self.place, no_check_set=['Reserve', 'DropoutState', 'State']) self.place,
atol=1e-4,
no_check_set=['Reserve', 'DropoutState', 'State'])
def set_attrs(self): def set_attrs(self):
pass pass
# def test_grad(self): def test_grad(self):
# if not self.is_test: if not self.is_test and self.sequence_length is None:
# var_name_list = self.get_weight_names() # if not self.is_test:
# grad_check_list = ['Input', 'init_h', 'init_c'] var_name_list = self.get_weight_names()
# grad_check_list.extend(var_name_list) grad_check_list = ['Input', 'init_h', 'init_c']
# self.check_grad_with_place(self.place, set(grad_check_list), grad_check_list.extend(var_name_list)
# ['Out', 'last_hidden', 'last_cell']) self.check_grad_with_place(self.place, set(grad_check_list),
['Out', 'last_hidden', 'last_cell'])
# class TestRNNOp1(TestRNNOp): class TestRNNOp1(TestRNNOp):
# def set_attrs(self): def set_attrs(self):
# self.sequence_length = None self.sequence_length = None
# class TestRNNOp2(TestRNNOp):
# def set_attrs(self): class TestRNNOp2(TestRNNOp):
# self.sequence_length = None
# self.is_bidirec = True
# class TestRNNOp3(TestRNNOp): def set_attrs(self):
self.sequence_length = None
self.is_bidirec = True
# def set_attrs(self):
# self.is_test = True
# self.sequence_length = None
# class TestRNNOp4(TestRNNOp): class TestRNNOp3(TestRNNOp):
def set_attrs(self):
self.is_test = True
self.sequence_length = None
class TestRNNOp4(TestRNNOp):
def set_attrs(self):
self.is_test = True
self.sequence_length = None
self.is_bidirec = True
# def set_attrs(self):
# self.is_test = True
# self.sequence_length = None
# self.is_bidirec = True
#TODO(chenxiao): cnnl doesn't support num_layers > 1 case #TODO(chenxiao): cnnl doesn't support num_layers > 1 case
# class TestRNNOp5(TestRNNOp): # class TestRNNOp5(TestRNNOp):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册