diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc index 1b40a4e74fa844fc37f9b9510c773f0d895781ee..c0619145ad5ab336de5541d3cbe74226d9661b86 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.cc +++ b/paddle/fluid/operators/mlu/mlu_baseop.cc @@ -620,6 +620,80 @@ MLUCnnlDCNDesc::~MLUCnnlDCNDesc() { } } +MLUSeqDataDesc::MLUSeqDataDesc(cnnlSeqDataLayout_t layout, + cnnlDataType_t dtype, + int dimNb, + const int dimSize[], + int seqLengthArraySize, + const int seqLengthArray[], + void* paddingFill) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateSeqDataDescriptor(&seq_data_desc_)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetSeqDataDescriptor(seq_data_desc_, + layout, + dtype, + dimNb, + dimSize, + seqLengthArraySize, + seqLengthArray, + paddingFill)); +} + +const cnnlSeqDataDescriptor_t MLUSeqDataDesc::get() const { + return seq_data_desc_; +} + +MLUSeqDataDesc::~MLUSeqDataDesc() { + if (seq_data_desc_) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroySeqDataDescriptor(seq_data_desc_)); + } +} + +MLURNNDesc::MLURNNDesc(const int hidden_size, + const int num_layers, + const cnnlRNNInputMode_t input_mode, + const cnnlDirectionMode_t direction, + const cnnlRNNMode_t rnn_mode) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateRNNDescriptor(&rnn_desc_)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetRNNDescriptor( + rnn_desc_, hidden_size, num_layers, input_mode, direction, rnn_mode)); +} + +MLURNNDesc::MLURNNDesc(cnnlRNNMode_t cell_mode, + cnnlRNNBiasMode_t bias_mode, + cnnlDirectionMode_t direction, + cnnlRNNInputMode_t input_mode, + cnnlDataType_t data_type, + cnnlDataType_t math_prec, + int input_size, + int hidden_size, + int proj_size, + int layer_num, + void* dropout_desc, + cnnlRNNPaddingMode_t padding_mode) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlCreateRNNDescriptor(&rnn_desc_)); + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetRNNDescriptor_v2(rnn_desc_, + cell_mode, + bias_mode, + direction, + input_mode, + data_type, + math_prec, + input_size, + hidden_size, + proj_size, + layer_num, + dropout_desc, + padding_mode)); +} + +const cnnlRNNDescriptor_t MLURNNDesc::get() const { return rnn_desc_; } + +MLURNNDesc::~MLURNNDesc() { + if (rnn_desc_) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlDestroyRNNDescriptor(rnn_desc_)); + } +} + /* static */ void MLUCnnl::Active(const ExecutionContext& ctx, cnnlActivationDescriptor_t active_desc, const cnnlTensorDescriptor_t input_desc, @@ -4471,6 +4545,105 @@ MLUCnnlDCNDesc::~MLUCnnlDCNDesc() { output)); } +/* static */ void MLUCnnl::RNNForward(const ExecutionContext& ctx, + const cnnlRNNDescriptor_t rnn_desc, + const int dev_seq_lengths[], + const void* weight_param_ptr, + size_t weightspace_size, + const cnnlSeqDataDescriptor_t x_desc, + const void* x, + const cnnlSeqDataDescriptor_t y_desc, + void* y, + const cnnlTensorDescriptor_t h_desc, + const void* hx, + void* hy, + const cnnlTensorDescriptor_t c_desc, + const void* cx, + void* cy, + void* reservespace_ptr) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + // make sure 1. cnnlSetRNNDescriptor_v2 is invoked + // 2. x_desc is not NULL + PADDLE_ENFORCE_NOT_NULL( + rnn_desc, + paddle::platform::errors::Fatal( + "MLU RNNForward failed. rnn_desc initializing failed.")); + PADDLE_ENFORCE_NOT_NULL( + x_desc, + paddle::platform::errors::Fatal( + "MLU RNNForward failed. x_desc initializing failed.")); + auto& dev_ctx = GetDevCtxFromCTX(ctx); + size_t workspace_size, reservespace_size; + Tensor workspace; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetRNNTempSizes( + handle, rnn_desc, x_desc, &workspace_size, &reservespace_size)); + workspace = ctx.AllocateTmpTensor( + {static_cast(workspace_size)}, dev_ctx); + + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlRNNForwardTraining(handle, + rnn_desc, + dev_seq_lengths, + x_desc, + x, + y_desc, + y, + h_desc, + hx, + hy, + c_desc, + cx, + cy, + weight_param_ptr, + weightspace_size, + workspace_ptr, + workspace_size, + reservespace_ptr, + reservespace_size)); +} + +/* static */ void MLUCnnl::Mask(const ExecutionContext& ctx, + cnnlMaskedOp_t masked_mode, + const cnnlTensorDescriptor_t input_desc, + const void* input, + const cnnlTensorDescriptor_t masked_desc, + const void* masked, + const cnnlTensorDescriptor_t value_desc, + const void* value, + const cnnlTensorDescriptor_t output_desc, + void* output, + uint32_t* number) { + cnnlHandle_t handle = GetHandleFromCTX(ctx); + auto& dev_ctx = GetDevCtxFromCTX(ctx); + size_t workspace_size; + Tensor workspace; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetMaskedWorkspaceSize(handle, + masked_mode, + input_desc, + masked_desc, + value_desc, + output_desc, + &workspace_size)); + workspace = ctx.AllocateTmpTensor( + {static_cast(workspace_size)}, dev_ctx); + void* workspace_ptr = workspace.mutable_data(ctx.GetPlace()); + + PADDLE_ENFORCE_MLU_SUCCESS(cnnlMasked_v3(handle, + masked_mode, + input_desc, + input, + masked_desc, + masked, + value_desc, + value, + workspace_ptr, + workspace_size, + output_desc, + output, + number)); +} + /* static */ void MLUCnnl::BceWithLogits( const ExecutionContext& ctx, cnnlBceWithLogitsReduction_t reduction, diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h index 8dcdd33e346d79009996dc8eab26d71cdc49dc66..9031040ec55984ddd7e76ce7b2d8a2e2dd6c15f4 100644 --- a/paddle/fluid/operators/mlu/mlu_baseop.h +++ b/paddle/fluid/operators/mlu/mlu_baseop.h @@ -495,6 +495,90 @@ class MLUCnnlDCNDesc { cnnlDCNDescriptor_t dcn_desc_ = nullptr; }; +class MLUSeqDataDesc { + public: + MLUSeqDataDesc(const MLUSeqDataDesc& desc) = delete; + MLUSeqDataDesc& operator=(const MLUSeqDataDesc& desc) = delete; + + MLUSeqDataDesc(cnnlSeqDataLayout_t layout, + cnnlDataType_t dtype, + int dimNb, + const int dimSize[], + int seqLengthArraySize, + const int seqLengthArray[], + void* paddingFill); + + const cnnlSeqDataDescriptor_t get() const; + + ~MLUSeqDataDesc(); + + private: + cnnlSeqDataDescriptor_t seq_data_desc_ = nullptr; +}; + +class MLURNNDesc { + public: + MLURNNDesc(const MLURNNDesc& desc) = delete; + MLURNNDesc& operator=(const MLURNNDesc& desc) = delete; + + MLURNNDesc(const int hidden_size, + const int num_layers, + const cnnlRNNInputMode_t input_mode, + const cnnlDirectionMode_t direction, + const cnnlRNNMode_t rnn_mode); + + MLURNNDesc(cnnlRNNMode_t cell_mode, + cnnlRNNBiasMode_t bias_mode, + cnnlDirectionMode_t direction, + cnnlRNNInputMode_t input_mode, + cnnlDataType_t data_type, + cnnlDataType_t math_prec, + int input_size, + int hidden_size, + int proj_size, + int layer_num, + void* dropout_desc, + cnnlRNNPaddingMode_t padding_mode); + + void SetRNNProjectionLayers(const int rec_proj_size, + const int out_proj_size) { + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlSetRNNProjectionLayers(rnn_desc_, rec_proj_size, out_proj_size)); + } + + void SetPeepholeMode(const cnnlRNNPeepholeMode_t peephole_mode) { + PADDLE_ENFORCE_MLU_SUCCESS( + cnnlSetRNNPeepholeMode(rnn_desc_, peephole_mode)); + } + + void SetRNNBiasMode(const cnnlRNNBiasMode_t bias_mode) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetRNNBiasMode(rnn_desc_, bias_mode)); + } + + void SetRNNMaskMode(const cnnlRNNMaskMode_t mask_mode) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetRNNMaskMode(rnn_desc_, mask_mode)); + } + + void SetRNNClip(const cnnlRNNClipMode_t clip_mode, + const cnnlNanPropagation_t clip_nan_opt, + const double left_clip, + const double right_clip) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetRNNClip( + rnn_desc_, clip_mode, clip_nan_opt, left_clip, right_clip)); + } + + void SetRNNPaddingMode(const cnnlRNNPaddingMode_t padding_mode) { + PADDLE_ENFORCE_MLU_SUCCESS(cnnlSetRNNPaddingMode(rnn_desc_, padding_mode)); + } + + const cnnlRNNDescriptor_t get() const; + + ~MLURNNDesc(); + + private: + cnnlRNNDescriptor_t rnn_desc_ = nullptr; +}; + class MLUCnnl { public: static void Active(const ExecutionContext& ctx, @@ -1814,6 +1898,35 @@ class MLUCnnl { const cnnlTensorDescriptor_t output_desc, void* output); + static void RNNForward(const ExecutionContext& ctx, + const cnnlRNNDescriptor_t rnn_desc, + const int dev_seq_lengths[], + const void* weight_param_ptr, + size_t weightspace_size, + const cnnlSeqDataDescriptor_t x_desc, + const void* x, + const cnnlSeqDataDescriptor_t y_desc, + void* y, + const cnnlTensorDescriptor_t h_desc, + const void* hx, + void* hy, + const cnnlTensorDescriptor_t c_desc, + const void* cx, + void* cy, + void* reservespace_ptr); + + static void Mask(const ExecutionContext& ctx, + cnnlMaskedOp_t masked_mode, + const cnnlTensorDescriptor_t input_desc, + const void* input, + const cnnlTensorDescriptor_t masked_desc, + const void* masked, + const cnnlTensorDescriptor_t value_desc, + const void* value, + const cnnlTensorDescriptor_t output_desc, + void* output, + uint32_t* number); + static void Transform(const ExecutionContext& ctx, const void* alpha, const void* beta, diff --git a/paddle/fluid/operators/rnn_op_mlu.cc b/paddle/fluid/operators/rnn_op_mlu.cc new file mode 100644 index 0000000000000000000000000000000000000000..653c50c83b83e185d108a89afe8938882d619e15 --- /dev/null +++ b/paddle/fluid/operators/rnn_op_mlu.cc @@ -0,0 +1,371 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/mlu/mlu_baseop.h" +#include "paddle/fluid/operators/utils.h" +#include "paddle/fluid/platform/device/device_wrapper.h" +#include "paddle/fluid/platform/device/xpu/xpu_header.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DDim = framework::DDim; +using TensorList = std::vector; +template +void reset_parameter_vector( + const std::vector& raw_params_vec, + const int& num_layers, + const bool& is_bidirec, + std::vector>>* params_vec) { + // the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers + // + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to + // ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers + const int& direction_num = is_bidirec ? 2 : 1; + const int& layer_weight_size = 4 * direction_num; + const int& all_weight_size = num_layers * layer_weight_size; + const int& bias_start_idx = all_weight_size / 2; + for (int i = 0; i < num_layers; i++) { + params_vec->at(i).resize(layer_weight_size); + for (int j = 0; j < layer_weight_size; j++) { + int k = j % 4; + const int& section = j / 4; + int tensor_idx = i * 2 * direction_num + section * 2 + k % 2; + if (k >= 2) { + tensor_idx += bias_start_idx; + } + using remove_cv_t = typename std::remove_cv::type; + params_vec->at(i)[j] = std::make_pair( + raw_params_vec[tensor_idx]->template data(), + raw_params_vec[tensor_idx]->numel() * sizeof(T)); + } + } +} + +template +class RNNMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + // Input + auto& dev_ctx = GetDevCtxFromCTX(ctx); + auto* input = ctx.Input("Input"); + auto pre_state = ctx.MultiInput("PreState"); + auto weight_list = ctx.MultiInput("WeightList"); + bool has_seq_length = ctx.HasInput("SequenceLength"); + // Output + auto state = ctx.MultiOutput("State"); + auto* output = ctx.Output("Out"); + // auto* dropout_mask = ctx.Output("DropoutState"); + auto* reserve_data = ctx.Output("Reserve"); + // Attributes + const int& num_layers = ctx.Attr("num_layers"); + const bool& is_bidirec = ctx.Attr("is_bidirec"); + const int& hidden_size = ctx.Attr("hidden_size"); + const std::string& mode = ctx.Attr("mode"); + + const Tensor* sequence_length = nullptr; + if (has_seq_length) { + sequence_length = ctx.Input("SequenceLength"); + } + + // if (dropout_mask->IsInitialized()) { + // if (dropout_mask->numel() != output->numel()) dropout_mask->clear(); + // } + // dropout_mask->mutable_data(output->dims(), ctx.GetPlace()); + // auto& dev_ctx = ctx.template device_context(); + // phi::funcs::SetConstant ones; + // ones(dev_ctx, dropout_mask, static_cast(1)); + + auto init_h = pre_state[0]; // -> hx + auto init_c = pre_state[1]; // -> cx + auto last_h = state[0]; + auto last_c = state[1]; + + // check shape + const int in_out_dim_num = input->dims().size(); + const int& seq_len = input->dims()[0]; // time_step + const int& batch_size = input->dims()[1]; + const int& input_dim = input->dims()[2]; + const int& direction_num = is_bidirec ? 2 : 1; + int in_dim_arr[in_out_dim_num] = {seq_len, batch_size, input_dim}; + int out_dim_arr[in_out_dim_num] = { + seq_len, batch_size, direction_num * hidden_size}; + int proj_size = hidden_size; + + std::vector seq_len_vec(batch_size, seq_len); + if (has_seq_length) { // set seq_len if no padding, otherwise seq_len for + // each element. + seq_len_vec = operators::GetDataFromTensor(sequence_length); + } + cnnlDirectionMode_t direction = + is_bidirec ? CNNL_RNN_BIDIRECTIONAL : CNNL_RNN_UNIDIRECTIONAL; + + PADDLE_ENFORCE_EQ( + mode, + "LSTM", + platform::errors::InvalidArgument( + "MLU only support LSTM mode now, current mode is %s", mode)); + PADDLE_ENFORCE_EQ( + num_layers, + 1, + platform::errors::InvalidArgument( + "MLU only support 1 num_layers, current num_layers is %s", + num_layers)); + PADDLE_ENFORCE_EQ( + init_h->dims()[0], + num_layers * direction_num, + platform::errors::InvalidArgument("The num_layers of in RNN layer must" + " be the same as first dim of init " + "hidden, but received num_layers:%d," + " dim:%d", + num_layers, + init_h->dims()[0])); + + PADDLE_ENFORCE_EQ( + init_c->dims()[0], + num_layers * direction_num, + platform::errors::InvalidArgument( + "The num_layers of in RNN layer must" + " be the same as first dim of cell state hidden, but received" + " num_layers:%d, dim:%d", + num_layers, + init_c->dims()[0])); + + // weightlist + std::vector>> parameter_lists; + parameter_lists.resize(num_layers); + reset_parameter_vector( + weight_list, num_layers, is_bidirec, ¶meter_lists); + + // init the output and allocate the memory + output->mutable_data(ctx.GetPlace()); // -> y in cnnl + last_h->mutable_data(ctx.GetPlace()); // -> hy in cnnl + last_c->mutable_data(ctx.GetPlace()); // -> cy in cnnl + + MLUSeqDataDesc input_seq_data_desc(CNNL_SEQDATA_TNC, + ToCnnlDataType(input->dtype()), + in_out_dim_num, + in_dim_arr, + static_cast(seq_len_vec.size()), + seq_len_vec.data(), + nullptr); + MLUSeqDataDesc out_seq_data_desc(CNNL_SEQDATA_TNC, + ToCnnlDataType(input->dtype()), + in_out_dim_num, + out_dim_arr, + static_cast(seq_len_vec.size()), + seq_len_vec.data(), + nullptr); + MLUCnnlTensorDesc hx_desc(*init_h); + MLUCnnlTensorDesc cx_desc(*init_c); + + MLURNNDesc rnn_desc(CNNL_LSTM, + CNNL_RNN_DOUBLE_BIAS, + direction, + CNNL_RNN_LINEAR_INPUT, + ToCnnlDataType(input->dtype()), + ToCnnlDataType(input->dtype()), + input_dim, + hidden_size, + /*projection*/ proj_size, + num_layers, + nullptr, + CNNL_RNN_PADDED_IO_DISABLED); + rnn_desc.SetRNNMaskMode(CNNL_LSTM_MASK_ENABLED); + + // copy weight params + size_t weightspace_size; + framework::Tensor weightspace; + PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetRNNWeightSpaceSize( + GetHandleFromCTX(ctx), rnn_desc.get(), &weightspace_size)); + + weightspace = ctx.AllocateTmpTensor( + {static_cast(weightspace_size)}, dev_ctx); + void* weightspace_ptr = weightspace.mutable_data(ctx.GetPlace()); + auto w_x = parameter_lists[0][0]; + auto w_h = parameter_lists[0][1]; + auto b_x = parameter_lists[0][2]; + auto b_h = parameter_lists[0][3]; + auto actual_total_w_size = + w_x.second + w_h.second + b_x.second + b_h.second; + + void* w_x_ptr = weightspace_ptr; + void* w_h_ptr = static_cast(weightspace_ptr) + w_x.second; + void* b_x_ptr = + static_cast(weightspace_ptr) + w_x.second + w_h.second; + void* b_h_ptr = static_cast(weightspace_ptr) + w_x.second + + w_h.second + b_x.second; + + memory::Copy(weightspace.place(), + w_x_ptr, + weightspace.place(), + w_x.first, + w_x.second, + nullptr); + memory::Copy(weightspace.place(), + w_h_ptr, + weightspace.place(), + w_h.first, + w_h.second, + nullptr); + memory::Copy(weightspace.place(), + b_x_ptr, + weightspace.place(), + b_x.first, + b_x.second, + nullptr); + memory::Copy(weightspace.place(), + b_h_ptr, + weightspace.place(), + b_h.first, + b_h.second, + nullptr); + + if (is_bidirec) { + auto bw_x = parameter_lists[0][4]; + auto bw_h = parameter_lists[0][5]; + auto bb_x = parameter_lists[0][6]; + auto bb_h = parameter_lists[0][7]; + void* bw_x_ptr = + static_cast(weightspace_ptr) + actual_total_w_size; + void* bw_h_ptr = static_cast(weightspace_ptr) + + actual_total_w_size + bw_x.second; + void* bb_x_ptr = static_cast(weightspace_ptr) + + actual_total_w_size + bw_x.second + bw_h.second; + void* bb_h_ptr = static_cast(weightspace_ptr) + + actual_total_w_size + bw_x.second + bw_h.second + + bb_x.second; + actual_total_w_size += + bw_x.second + bw_h.second + bb_x.second + bb_h.second; + + memory::Copy(weightspace.place(), + bw_x_ptr, + weightspace.place(), + bw_x.first, + bw_x.second, + nullptr); + memory::Copy(weightspace.place(), + bw_h_ptr, + weightspace.place(), + bw_h.first, + bw_h.second, + nullptr); + memory::Copy(weightspace.place(), + bb_x_ptr, + weightspace.place(), + bb_x.first, + bb_x.second, + nullptr); + memory::Copy(weightspace.place(), + bb_h_ptr, + weightspace.place(), + bb_h.first, + bb_h.second, + nullptr); + } + + PADDLE_ENFORCE_EQ(weightspace_size, + actual_total_w_size, + platform::errors::InvalidArgument( + "The weightsize doesn't match" + " weightspace_size:%d, actual_total_w_size:%d", + weightspace_size, + actual_total_w_size)); + + // get reservespace_ptr + int gate_num = 4; + int hidden_data_idx = (num_layers - 1); + hidden_data_idx += (gate_num + 1) * num_layers; + const int& block_size = direction_num * seq_len * batch_size * hidden_size; + reserve_data->Resize({hidden_data_idx, block_size}); + + reserve_data->mutable_data(ctx.GetPlace()); + + MLUCnnl::RNNForward(ctx, + rnn_desc.get(), + seq_len_vec.data(), + weightspace_ptr, + weightspace_size, + input_seq_data_desc.get(), + GetBasePtr(input), + out_seq_data_desc.get(), + GetBasePtr(output), + hx_desc.get(), + GetBasePtr(init_h), + GetBasePtr(last_h), + cx_desc.get(), + GetBasePtr(init_c), + GetBasePtr(last_c), + GetBasePtr(reserve_data)); + + if (has_seq_length) { + // if has_seq_length, do mask out the output of cnnlRNNForwardTraining + auto masked_mode = CNNL_MASKED_FILL; + float off_value = 0.0f; + + framework::Tensor on_value_tensor(input->dtype()); + framework::Tensor masked_tensor(framework::TransToPhiDataType(VT::INT8)); + framework::Tensor h_masked_tensor( + framework::TransToPhiDataType(VT::INT8)); + on_value_tensor.Resize({1}); + masked_tensor.Resize({seq_len, batch_size, direction_num * hidden_size}); + h_masked_tensor.Resize( + {seq_len, batch_size, direction_num * hidden_size}); + + on_value_tensor.mutable_data(ctx.GetPlace()); + masked_tensor.mutable_data(ctx.GetPlace()); + int8_t* h_masked_ptr = + h_masked_tensor.mutable_data(platform::CPUPlace()); + + for (int t = 0; t < seq_len; ++t) { + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < direction_num * hidden_size; ++c) { + auto tmp_seq_len = seq_len_vec[n]; + auto offset = t * batch_size * direction_num * hidden_size + + n * direction_num * hidden_size + c; + *(h_masked_ptr + offset) = t >= tmp_seq_len ? 1 : 0; + } + } + } + + framework::TensorCopy( + h_masked_tensor, ctx.GetPlace(), dev_ctx, &masked_tensor); + dev_ctx.Wait(); + + FillMLUTensorWithHostValue(ctx, off_value, &on_value_tensor); + MLUCnnlTensorDesc on_value_desc(on_value_tensor); + MLUCnnlTensorDesc output_desc(*output); + MLUCnnlTensorDesc masked_desc(masked_tensor); + + MLUCnnl::Mask(ctx, + masked_mode, + output_desc.get(), + GetBasePtr(output), + masked_desc.get(), + GetBasePtr(&masked_tensor), + on_value_desc.get(), + GetBasePtr(&on_value_tensor), + output_desc.get(), + GetBasePtr(output), + nullptr); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_MLU_KERNEL( + rnn, ops::RNNMLUKernel); diff --git a/python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py new file mode 100644 index 0000000000000000000000000000000000000000..f1aabbd3b603bd54647c496ec5dc448897aaeb8f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_rnn_op_mlu.py @@ -0,0 +1,208 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import math +import paddle.fluid.core as core +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import random +import sys + +sys.path.append('..') +from op_test import OpTest + +sys.path.append("../rnn") +from rnn_numpy import SimpleRNN, LSTM, GRU +from convert import get_params_for_net + +random.seed(2) +np.set_printoptions(threshold=np.inf) +paddle.enable_static() + + +class TestRNNOp(OpTest): + + def get_weight_names(self): + weight_names = [] + for i in range(self.num_layers): + for j in range(0, 2 * self.direction_num): + weight_names.append("{}.weight_{}".format(i, j)) + for i in range(self.num_layers): + for j in range(0, 2 * self.direction_num): + weight_names.append("{}.bias_{}".format(i, j)) + return weight_names + + def setUp(self): + self.__class__.use_mlu = True + self.place = paddle.device.MLUPlace(0) + self.in_type = np.float32 + self.init_dtype() + self.init_size() + self.op_type = "rnn" + self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32) + self.num_layers = 1 + self.is_bidirec = False + self.mode = "LSTM" + self.is_test = False + self.dropout = 0.0 + self.set_attrs() + + self.direction_num = 2 if self.is_bidirec else 1 + direction = "bidirectional" if self.is_bidirec else "forward" + + input = np.random.uniform(low=-0.1, + high=0.1, + size=(self.seq_length, self.batch_size, + self.input_size)).astype(self.dtype) + + input[11][1:][:] = 0 + input[10][2:][:] = 0 + input[9][3:][:] = 0 + input[8][4:][:] = 0 + + rnn1 = LSTM(self.input_size, + self.hidden_size, + num_layers=self.num_layers, + time_major=True, + direction=direction, + dropout=self.dropout, + dtype=self.dtype) + + flat_w = get_params_for_net(rnn1) + output, (last_hidden, + last_cell) = rnn1(input, sequence_length=self.sequence_length) + + init_h = np.zeros( + (self.num_layers * self.direction_num, self.batch_size, + self.hidden_size)).astype(self.dtype) + init_c = np.zeros( + (self.num_layers * self.direction_num, self.batch_size, + self.hidden_size)).astype(self.dtype) + state_out = np.ndarray((300)).astype("uint8") + + self.inputs = { + 'Input': input, + 'WeightList': flat_w, + 'PreState': [('init_h', init_h), ('init_c', init_c)], + 'SequenceLength': self.sequence_length + } + if self.sequence_length is None: + self.inputs = { + 'Input': input, + 'WeightList': flat_w, + 'PreState': [('init_h', init_h), ('init_c', init_c)], + } + self.attrs = { + 'dropout_prob': self.dropout, + 'is_bidirec': self.is_bidirec, + 'input_size': self.input_size, + 'hidden_size': self.hidden_size, + 'num_layers': self.num_layers, + 'mode': self.mode, + 'is_test': self.is_test + } + self.outputs = { + 'Out': output, + "State": [('last_hidden', last_hidden), ('last_cell', last_cell)], + 'Reserve': np.ndarray((400)).astype("uint8"), + 'DropoutState': state_out + } + + def init_dtype(self): + self.dtype = self.in_type + + def init_size(self): + self.seq_length = 12 + self.batch_size = 5 + self.input_size = 3 + self.hidden_size = 2 + + def test_output(self): + self.check_output_with_place( + self.place, no_check_set=['Reserve', 'DropoutState', 'State']) + + def set_attrs(self): + pass + + # def test_grad(self): + # if not self.is_test: + # var_name_list = self.get_weight_names() + # grad_check_list = ['Input', 'init_h', 'init_c'] + # grad_check_list.extend(var_name_list) + # self.check_grad_with_place(self.place, set(grad_check_list), + # ['Out', 'last_hidden', 'last_cell']) + + +# class TestRNNOp1(TestRNNOp): + +# def set_attrs(self): +# self.sequence_length = None + +# class TestRNNOp2(TestRNNOp): + +# def set_attrs(self): +# self.sequence_length = None +# self.is_bidirec = True + +# class TestRNNOp3(TestRNNOp): + +# def set_attrs(self): +# self.is_test = True +# self.sequence_length = None + +# class TestRNNOp4(TestRNNOp): + +# def set_attrs(self): +# self.is_test = True +# self.sequence_length = None +# self.is_bidirec = True + +#TODO(chenxiao): cnnl doesn't support num_layers > 1 case +# class TestRNNOp5(TestRNNOp): + +# def set_attrs(self): +# self.num_layers = 2 + +# class TestRNNOp6(TestRNNOp): + +# def set_attrs(self): +# self.num_layers = 2 +# self.is_bidirec = True + +# class TestRNNOp7(TestRNNOp): + +# def set_attrs(self): +# self.num_layers = 2 +# self.is_bidirec = True +# self.is_test = True + +# class TestRNNOp8(TestRNNOp): + +# def set_attrs(self): +# self.num_layers = 2 +# self.is_bidirec = True +# self.sequence_length = None + +# class TestRNNOp9(TestRNNOp): + +# def set_attrs(self): +# self.num_layers = 3 + +if __name__ == '__main__': + unittest.main()