Unverified commit d4710dfe, authored by z8hanghuan, committed by GitHub

modify unittest of lstm forward, *test=kunlun (#41534)

* modify unittest of lstm forward, *test=kunlun

* modify unittest of lstm forward, *test=kunlun
Parent 0a6fe699
@@ -36,7 +36,7 @@ ENDIF()
 
 if(NOT DEFINED XPU_BASE_URL)
   SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
-  SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220402")
+  SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220408")
 else()
   SET(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
......
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/device/device_wrapper.h"
 #include "paddle/fluid/platform/device/xpu/xpu_header.h"
 #include "paddle/fluid/platform/device_context.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace paddle {
 namespace operators {
@@ -114,6 +115,9 @@ class RnnXPUKernel : public framework::OpKernel<T> {
       if (dropout_mask->numel() != output->numel()) dropout_mask->clear();
     }
     dropout_mask->mutable_data<uint8_t>(output->dims(), ctx.GetPlace());
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    phi::funcs::SetConstant<platform::XPUDeviceContext, uint8_t> ones;
+    ones(dev_ctx, dropout_mask, static_cast<uint8_t>(1));
 
     PADDLE_ENFORCE_EQ(
         mode, "LSTM",
@@ -190,7 +194,6 @@ class RnnXPUKernel : public framework::OpKernel<T> {
       seq_len_tensor = operators::GetDataFromTensor(sequence_length);
    }
 
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
     int state_offset = pre_state[0]->dims()[1] * pre_state[0]->dims()[2];
 
     for (int i = 0; i < num_layers; i++) {
......
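Note on the kernel change above: DropoutState is now filled with ones immediately after allocation, so the mask buffer has a defined value even on paths where no dropout is applied. A minimal numpy sketch of why an all-ones uint8 mask leaves the forward output unchanged (shapes are illustrative, not the kernel's actual buffers):

import numpy as np

# Stand-ins for the kernel's output tensor and its DropoutState buffer.
output = np.random.uniform(-0.1, 0.1, size=(12, 5, 2)).astype(np.float32)
dropout_mask = np.ones(output.shape, dtype=np.uint8)  # mirrors SetConstant(..., 1)

# Every mask element is 1, so applying the mask is the identity --
# consistent with an LSTM forward pass that applies no dropout.
assert np.array_equal(output * dropout_mask, output)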
@@ -46,8 +46,9 @@ class XPUTestRNNOp(XPUOpTestWrapper):
             self.init_dtype()
             self.op_type = "rnn"
             self.place = paddle.XPUPlace(0)
-            self.sequence_length = np.ones(
-                (self.batch_size, ), dtype=np.int32) * self.seq_length
+            self.sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32)
+            self.num_layers = 1
+            self.is_bidirec = False
             self.set_attrs()
             self.mode = "LSTM"
             self.is_test = False
@@ -61,6 +62,10 @@ class XPUTestRNNOp(XPUOpTestWrapper):
                 high=0.1,
                 size=(self.seq_length, self.batch_size,
                       self.input_size)).astype(self.dtype)
+            input[11][1:][:] = 0
+            input[10][2:][:] = 0
+            input[9][3:][:] = 0
+            input[8][4:][:] = 0
 
             rnn1 = LSTM(
                 self.input_size,
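The four added assignments zero out the padded time steps so each batch entry agrees with its declared length in self.sequence_length ([12, 11, 10, 9, 8]): entry b is valid for its first 12 - b steps. A sketch of the equivalent general loop (illustrative only; the test spells out the four cases explicitly):

# input has shape (seq_length=12, batch_size=5, input_size=3);
# sequence_length[b] is the number of valid time steps for batch entry b.
for b, length in enumerate([12, 11, 10, 9, 8]):
    input[length:, b, :] = 0  # clear every time step past the valid length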
@@ -126,10 +131,10 @@ class XPUTestRNNOp(XPUOpTestWrapper):
                 no_check_set=['Reserve', 'DropoutState'])
 
         def init_size(self):
-            self.seq_length = 1
-            self.batch_size = 1
-            self.input_size = 5
-            self.hidden_size = 16
+            self.seq_length = 12
+            self.batch_size = 5
+            self.input_size = 3
+            self.hidden_size = 2
 
         def get_weight_names(self):
             weight_names = []
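The new sizes in init_size are tied to the variable-length setup in setUp: sequence_length must carry one entry per batch element, and no entry may exceed seq_length. An illustrative consistency check (not part of the test file):

import numpy as np

sequence_length = np.array([12, 11, 10, 9, 8], dtype=np.int32)
assert len(sequence_length) == 5       # batch_size
assert sequence_length.max() == 12     # seq_length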
@@ -142,38 +147,18 @@ class XPUTestRNNOp(XPUOpTestWrapper):
             return weight_names
 
         def set_attrs(self):
-            self.num_layers = 1
-            self.is_bidirec = False
+            pass
 
     class TestRNNOp1(TestRNNOp):
-        def init_size(self):
-            self.seq_length = 2
-            self.batch_size = 4
-            self.input_size = 10
-            self.hidden_size = 32
-
         def set_attrs(self):
             self.num_layers = 1
             self.is_bidirec = False
+            self.sequence_length = None
 
     class TestRNNOp2(TestRNNOp):
-        def init_size(self):
-            self.seq_length = 5
-            self.batch_size = 16
-            self.input_size = 30
-            self.hidden_size = 64
-
         def set_attrs(self):
             self.num_layers = 1
             self.is_bidirec = True
 
     class TestRNNOp3(TestRNNOp):
-        def init_size(self):
-            self.seq_length = 10
-            self.batch_size = 64
-            self.input_size = 50
-            self.hidden_size = 64
-
         def set_attrs(self):
             self.num_layers = 2
             self.is_bidirec = False
@@ -188,6 +173,17 @@ class XPUTestRNNOp(XPUOpTestWrapper):
             self.num_layers = 2
             self.is_bidirec = True
 
+    class TestRNNOp6(TestRNNOp):
+        def set_attrs(self):
+            self.num_layers = 2
+            self.is_bidirec = True
+            self.sequence_length = None
+
+    class TestRNNOp7(TestRNNOp):
+        def set_attrs(self):
+            self.num_layers = 3
+            self.is_bidirec = True
+
 
 support_types = get_xpu_op_support_types('rnn')
 for stype in support_types:
......