提交 02054094 编写于 作者: T tensor-tang 提交者: ceci3

add jitkernel vcopy and speedup unit test time

test=develop
上级 7e5a4a3d
...@@ -498,6 +498,7 @@ BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); } ...@@ -498,6 +498,7 @@ BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); } BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); } BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); } BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
BENCH_FP32_CPU(kVCopy) { BenchXYNKernel<jit::kVCopy, T, CPUPlace>(); }
// lstm and peephole // lstm and peephole
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); } BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
......
...@@ -36,6 +36,7 @@ const char* to_string(KernelType kt) { ...@@ -36,6 +36,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kVScal); ONE_CASE(kVScal);
ONE_CASE(kVAddBias); ONE_CASE(kVAddBias);
ONE_CASE(kVRelu); ONE_CASE(kVRelu);
ONE_CASE(kVCopy);
ONE_CASE(kVIdentity); ONE_CASE(kVIdentity);
ONE_CASE(kVExp); ONE_CASE(kVExp);
ONE_CASE(kVSquare); ONE_CASE(kVSquare);
......
...@@ -41,6 +41,7 @@ typedef enum { ...@@ -41,6 +41,7 @@ typedef enum {
kVAdd, kVAdd,
kVAddBias, kVAddBias,
kVAddRelu, kVAddRelu,
kVCopy,
kVExp, kVExp,
kVIdentity, kVIdentity,
kVMul, kVMul,
......
...@@ -9,6 +9,7 @@ USE_JITKERNEL_MORE(kVAdd, mkl) ...@@ -9,6 +9,7 @@ USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl) USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl) USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSquare, mkl) USE_JITKERNEL_MORE(kVSquare, mkl)
USE_JITKERNEL_MORE(kVCopy, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl) USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl) USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl) USE_JITKERNEL_MORE(kSeqPool, mkl)
......
...@@ -154,6 +154,11 @@ bool VSquareKernel<float>::UseMe(const int& d) const { ...@@ -154,6 +154,11 @@ bool VSquareKernel<float>::UseMe(const int& d) const {
return d > 7; return d > 7;
} }
template <>
bool VCopyKernel<float>::UseMe(const int& d) const {
return d > 15;
}
template <> template <>
bool VSigmoidKernel<float>::UseMe(const int& d) const { bool VSigmoidKernel<float>::UseMe(const int& d) const {
return d > 7; return d > 7;
...@@ -223,6 +228,7 @@ AWALYS_USE_ME_WITH_DOUBLE(VExp); ...@@ -223,6 +228,7 @@ AWALYS_USE_ME_WITH_DOUBLE(VExp);
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid); AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
AWALYS_USE_ME_WITH_DOUBLE(VTanh); AWALYS_USE_ME_WITH_DOUBLE(VTanh);
AWALYS_USE_ME_WITH_DOUBLE(VSquare); AWALYS_USE_ME_WITH_DOUBLE(VSquare);
AWALYS_USE_ME_WITH_DOUBLE(VCopy);
AWALYS_USE_ME_WITH_DOUBLE(Softmax); AWALYS_USE_ME_WITH_DOUBLE(Softmax);
#undef AWALYS_USE_ME_WITH_DOUBLE #undef AWALYS_USE_ME_WITH_DOUBLE
...@@ -244,6 +250,7 @@ REGISTER_MKL_KERNEL(kVAdd, VAdd); ...@@ -244,6 +250,7 @@ REGISTER_MKL_KERNEL(kVAdd, VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal); REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp); REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSquare, VSquare); REGISTER_MKL_KERNEL(kVSquare, VSquare);
REGISTER_MKL_KERNEL(kVCopy, VCopy);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh); REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool); REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
......
...@@ -192,6 +192,7 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples); ...@@ -192,6 +192,7 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples); DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples); DECLARE_MKL_KERNEL(VTanh, XYNTuples);
DECLARE_MKL_KERNEL(VSquare, XYNTuples); DECLARE_MKL_KERNEL(VSquare, XYNTuples);
DECLARE_MKL_KERNEL(VCopy, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples); DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
......
...@@ -13,6 +13,7 @@ USE_JITKERNEL_REFER(kVAddRelu) ...@@ -13,6 +13,7 @@ USE_JITKERNEL_REFER(kVAddRelu)
USE_JITKERNEL_REFER(kVSub) USE_JITKERNEL_REFER(kVSub)
USE_JITKERNEL_REFER(kVScal) USE_JITKERNEL_REFER(kVScal)
USE_JITKERNEL_REFER(kVAddBias) USE_JITKERNEL_REFER(kVAddBias)
USE_JITKERNEL_REFER(kVCopy)
USE_JITKERNEL_REFER(kVRelu) USE_JITKERNEL_REFER(kVRelu)
USE_JITKERNEL_REFER(kVIdentity) USE_JITKERNEL_REFER(kVIdentity)
USE_JITKERNEL_REFER(kVExp) USE_JITKERNEL_REFER(kVExp)
......
...@@ -30,6 +30,7 @@ REGISTER_REFER_KERNEL(kVScal, VScal); ...@@ -30,6 +30,7 @@ REGISTER_REFER_KERNEL(kVScal, VScal);
REGISTER_REFER_KERNEL(kVAddBias, VAddBias); REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
REGISTER_REFER_KERNEL(kVRelu, VRelu); REGISTER_REFER_KERNEL(kVRelu, VRelu);
REGISTER_REFER_KERNEL(kVCopy, VCopy);
REGISTER_REFER_KERNEL(kVIdentity, VIdentity); REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
REGISTER_REFER_KERNEL(kVSquare, VSquare); REGISTER_REFER_KERNEL(kVSquare, VSquare);
REGISTER_REFER_KERNEL(kVExp, VExp); REGISTER_REFER_KERNEL(kVExp, VExp);
......
...@@ -70,6 +70,11 @@ void VAddBias(const T* a, const T* x, T* y, int n) { ...@@ -70,6 +70,11 @@ void VAddBias(const T* a, const T* x, T* y, int n) {
} }
} }
template <typename T>
void VCopy(const T* x, T* y, int n) {
std::memcpy(y, x, n * sizeof(T));
}
template <typename T> template <typename T>
void VRelu(const T* x, T* y, int n) { void VRelu(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
...@@ -500,6 +505,7 @@ DECLARE_REFER_KERNEL(VExp, XYNTuples); ...@@ -500,6 +505,7 @@ DECLARE_REFER_KERNEL(VExp, XYNTuples);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples); DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
DECLARE_REFER_KERNEL(VTanh, XYNTuples); DECLARE_REFER_KERNEL(VTanh, XYNTuples);
DECLARE_REFER_KERNEL(VSquare, XYNTuples); DECLARE_REFER_KERNEL(VSquare, XYNTuples);
DECLARE_REFER_KERNEL(VCopy, XYNTuples);
// lstm_t*, const lstm_attr_t* // lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples); DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
......
...@@ -26,8 +26,8 @@ limitations under the License. */ ...@@ -26,8 +26,8 @@ limitations under the License. */
DEFINE_double(acc, 1e-5, "Test accuracy threshold."); DEFINE_double(acc, 1e-5, "Test accuracy threshold.");
template <typename T> template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f), void RandomVec(const int n, T* a, const T lower = static_cast<T>(-2.f),
const T upper = static_cast<T>(20.f)) { const T upper = static_cast<T>(2.f)) {
static unsigned int seed = 100; static unsigned int seed = 100;
std::mt19937 rng(seed++); std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1); std::uniform_real_distribution<double> uniform_dist(0, 1);
...@@ -514,7 +514,7 @@ void TestKernelXRNTuples() { ...@@ -514,7 +514,7 @@ void TestKernelXRNTuples() {
auto ref = jit::GetRefer<KT, jit::XRNTuples<T>>(); auto ref = jit::GetRefer<KT, jit::XRNTuples<T>>();
EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d); std::vector<T> x(d);
RandomVec<T>(d, x.data(), -2.f, 2.f); RandomVec<T>(d, x.data());
T ref_res; T ref_res;
ref(x.data(), &ref_res, d); ref(x.data(), &ref_res, d);
TestAllImpls<KT, jit::XRNTuples<T>, PlaceType, std::vector<T>, T>(d, x, TestAllImpls<KT, jit::XRNTuples<T>, PlaceType, std::vector<T>, T>(d, x,
...@@ -532,7 +532,7 @@ void TestKernelXYNTuples() { ...@@ -532,7 +532,7 @@ void TestKernelXYNTuples() {
std::vector<T> x(d), yref(d); std::vector<T> x(d), yref(d);
std::vector<T> xinp(d); // inplace test std::vector<T> xinp(d); // inplace test
RandomVec<T>(d, x.data(), -2.f, 2.f); RandomVec<T>(d, x.data());
std::copy(x.begin(), x.end(), xinp.begin()); std::copy(x.begin(), x.end(), xinp.begin());
const T* x_data = x.data(); const T* x_data = x.data();
...@@ -566,7 +566,7 @@ void TestKernelLSTMTuples() { ...@@ -566,7 +566,7 @@ void TestKernelLSTMTuples() {
EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(ref != nullptr);
std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d); std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d);
std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d); std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d);
RandomVec<T>(4 * d, xsrc.data(), -2.f, 2.f); RandomVec<T>(4 * d, xsrc.data());
RandomVec<T>(3 * d, wp.data(), -1.f, 1.f); RandomVec<T>(3 * d, wp.data(), -1.f, 1.f);
RandomVec<T>(d, ct_1.data(), -1.f, 1.f); RandomVec<T>(d, ct_1.data(), -1.f, 1.f);
// x could be changed after compute, so copy to save src // x could be changed after compute, so copy to save src
...@@ -614,8 +614,8 @@ void TestKernelGRUTuples() { ...@@ -614,8 +614,8 @@ void TestKernelGRUTuples() {
auto ref = jit::GetRefer<KT, jit::GRUTuples<T>>(); auto ref = jit::GetRefer<KT, jit::GRUTuples<T>>();
EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(ref != nullptr);
std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d); std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d);
RandomVec<T>(3 * d, xsrc.data(), -2.f, 2.f); RandomVec<T>(3 * d, xsrc.data());
RandomVec<T>(d, ht_1.data(), -2.f, 2.f); RandomVec<T>(d, ht_1.data());
// x could be changed after compute, so copy to save src // x could be changed after compute, so copy to save src
std::vector<T> x(xsrc.size()); std::vector<T> x(xsrc.size());
std::copy(xsrc.begin(), xsrc.end(), x.begin()); std::copy(xsrc.begin(), xsrc.end(), x.begin());
...@@ -651,7 +651,7 @@ void TestKernelSeqPoolTuples() { ...@@ -651,7 +651,7 @@ void TestKernelSeqPoolTuples() {
auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>(); auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(ref != nullptr);
std::vector<T> x(h * w), yref(w); std::vector<T> x(h * w), yref(w);
RandomVec<T>(h * w, x.data(), -2.f, 2.f); RandomVec<T>(h * w, x.data());
const T* x_data = x.data(); const T* x_data = x.data();
T* yref_data = yref.data(); T* yref_data = yref.data();
ref(x_data, yref_data, &attr); ref(x_data, yref_data, &attr);
...@@ -676,8 +676,8 @@ void TestKernelMatMulTuples() { ...@@ -676,8 +676,8 @@ void TestKernelMatMulTuples() {
auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>(); auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>();
EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(ref != nullptr);
std::vector<T> a(m * k), b(k * n), c(m * n); std::vector<T> a(m * k), b(k * n), c(m * n);
RandomVec<T>(m * k, a.data(), -2.f, 2.f); RandomVec<T>(m * k, a.data());
RandomVec<T>(k * n, b.data(), -2.f, 2.f); RandomVec<T>(k * n, b.data());
const T* a_data = a.data(); const T* a_data = a.data();
const T* b_data = b.data(); const T* b_data = b.data();
T* c_data = c.data(); T* c_data = c.data();
...@@ -699,7 +699,7 @@ void TestKernelSoftmaxTuples() { ...@@ -699,7 +699,7 @@ void TestKernelSoftmaxTuples() {
auto ref = jit::GetRefer<KT, jit::SoftmaxTuples<T>>(); auto ref = jit::GetRefer<KT, jit::SoftmaxTuples<T>>();
EXPECT_TRUE(ref != nullptr); EXPECT_TRUE(ref != nullptr);
std::vector<T> x(bs * n), y(bs * n); std::vector<T> x(bs * n), y(bs * n);
RandomVec<T>(bs * n, x.data(), -2.f, 2.f); RandomVec<T>(bs * n, x.data());
const T* x_data = x.data(); const T* x_data = x.data();
T* y_data = y.data(); T* y_data = y.data();
...@@ -726,7 +726,7 @@ void TestKernelEmbSeqPoolTuples() { ...@@ -726,7 +726,7 @@ void TestKernelEmbSeqPoolTuples() {
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000)); test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int tbl_w : test_sizes) { for (int tbl_w : test_sizes) {
std::vector<T> table(tbl_h * tbl_w); std::vector<T> table(tbl_h * tbl_w);
RandomVec<T>(tbl_h * tbl_w, table.data(), -2.f, 2.f); RandomVec<T>(tbl_h * tbl_w, table.data());
const T* table_data = table.data(); const T* table_data = table.data();
for (auto type : pool_types) { for (auto type : pool_types) {
for (int idx_w : {1, 2, 10, 16}) { for (int idx_w : {1, 2, 10, 16}) {
...@@ -772,14 +772,14 @@ void TestKernelSgdTuples() { ...@@ -772,14 +772,14 @@ void TestKernelSgdTuples() {
for (int grad_w : TestSizes()) { for (int grad_w : TestSizes()) {
std::vector<T> param(param_h * grad_w); std::vector<T> param(param_h * grad_w);
std::vector<T> param_out(param_h * grad_w); std::vector<T> param_out(param_h * grad_w);
RandomVec<T>(param_h * grad_w, param.data(), -2.f, 2.f); RandomVec<T>(param_h * grad_w, param.data());
const T* param_data = param.data(); const T* param_data = param.data();
T* out_data = param_out.data(); T* out_data = param_out.data();
for (int rows_size = 1; rows_size <= param_h; ++rows_size) { for (int rows_size = 1; rows_size <= param_h; ++rows_size) {
std::vector<T> grad(rows_size * grad_w); std::vector<T> grad(rows_size * grad_w);
std::vector<int64_t> rows = std::vector<int64_t> rows =
UnDuplicatedRandomVec(rows_size, 0, rows_size - 1); UnDuplicatedRandomVec(rows_size, 0, rows_size - 1);
RandomVec<T>(rows_size * grad_w, grad.data(), -2.f, 2.f); RandomVec<T>(rows_size * grad_w, grad.data());
const int64_t* rows_data = rows.data(); const int64_t* rows_data = rows.data();
const T* grad_data = grad.data(); const T* grad_data = grad.data();
auto ref = jit::GetRefer<KT, jit::SgdTuples<T>>(); auto ref = jit::GetRefer<KT, jit::SgdTuples<T>>();
...@@ -815,8 +815,8 @@ void TestKernelNCHW16CMulNCTuples() { ...@@ -815,8 +815,8 @@ void TestKernelNCHW16CMulNCTuples() {
int sz = n * c * h * w; int sz = n * c * h * w;
std::vector<T> x(sz), y(n * c), zref(sz); std::vector<T> x(sz), y(n * c), zref(sz);
std::vector<T> ztgt(sz), zjit(sz); std::vector<T> ztgt(sz), zjit(sz);
RandomVec<T>(sz, x.data(), -2.f, 2.f); RandomVec<T>(sz, x.data());
RandomVec<T>(n * c, y.data(), -2.f, 2.f); RandomVec<T>(n * c, y.data());
const T* x_data = x.data(); const T* x_data = x.data();
const T* y_data = y.data(); const T* y_data = y.data();
...@@ -873,11 +873,11 @@ void TestKernelLayerNormTuples() { ...@@ -873,11 +873,11 @@ void TestKernelLayerNormTuples() {
int sz = left * right; int sz = left * right;
std::vector<T> x(sz), mean(left), var(left), scale(right), bias(right), std::vector<T> x(sz), mean(left), var(left), scale(right), bias(right),
outref(sz); outref(sz);
RandomVec<T>(sz, x.data(), -2.f, 2.f); RandomVec<T>(sz, x.data());
RandomVec<T>(left, mean.data(), -2.f, 2.f); RandomVec<T>(left, mean.data());
RandomVec<T>(left, var.data(), -2.f, 2.f); RandomVec<T>(left, var.data());
RandomVec<T>(right, scale.data(), -2.f, 2.f); RandomVec<T>(right, scale.data());
RandomVec<T>(right, bias.data(), -2.f, 2.f); RandomVec<T>(right, bias.data());
const T* scale_data = scale.data(); const T* scale_data = scale.data();
const T* bias_data = bias.data(); const T* bias_data = bias.data();
...@@ -903,7 +903,7 @@ void TestKernelCRFDecodingTuples() { ...@@ -903,7 +903,7 @@ void TestKernelCRFDecodingTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
constexpr int state_trans_base_idx = 2; constexpr int state_trans_base_idx = 2;
auto test_sizes = TestSizes(); auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000)); test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 2000));
for (int seq_len : {1, 11, 17, 50}) { for (int seq_len : {1, 11, 17, 50}) {
for (int tag_num : test_sizes) { for (int tag_num : test_sizes) {
auto ref = jit::GetRefer<KT, jit::CRFDecodingTuples<T>>(); auto ref = jit::GetRefer<KT, jit::CRFDecodingTuples<T>>();
...@@ -912,8 +912,8 @@ void TestKernelCRFDecodingTuples() { ...@@ -912,8 +912,8 @@ void TestKernelCRFDecodingTuples() {
int w_sz = (tag_num + state_trans_base_idx) * tag_num; int w_sz = (tag_num + state_trans_base_idx) * tag_num;
std::vector<T> x(x_sz), w(w_sz), alpharef(x_sz); std::vector<T> x(x_sz), w(w_sz), alpharef(x_sz);
std::vector<int> trackref(x_sz); std::vector<int> trackref(x_sz);
RandomVec<T>(x_sz, x.data(), -2.f, 2.f); RandomVec<T>(x_sz, x.data());
RandomVec<T>(w_sz, w.data(), -2.f, 2.f); RandomVec<T>(w_sz, w.data());
ref(seq_len, (const T*)x.data(), (const T*)w.data(), alpharef.data(), ref(seq_len, (const T*)x.data(), (const T*)w.data(), alpharef.data(),
trackref.data(), tag_num); trackref.data(), tag_num);
...@@ -949,6 +949,7 @@ TEST_CPU_KERNEL(XYNTuples, kVSquare); ...@@ -949,6 +949,7 @@ TEST_CPU_KERNEL(XYNTuples, kVSquare);
TEST_CPU_KERNEL(XYNTuples, kVExp); TEST_CPU_KERNEL(XYNTuples, kVExp);
TEST_CPU_KERNEL(XYNTuples, kVSigmoid); TEST_CPU_KERNEL(XYNTuples, kVSigmoid);
TEST_CPU_KERNEL(XYNTuples, kVTanh); TEST_CPU_KERNEL(XYNTuples, kVTanh);
TEST_CPU_KERNEL(XYNTuples, kVCopy);
TEST_CPU_KERNEL(LSTMTuples, kLSTMCtHt); TEST_CPU_KERNEL(LSTMTuples, kLSTMCtHt);
TEST_CPU_KERNEL(LSTMTuples, kLSTMC1H1); TEST_CPU_KERNEL(LSTMTuples, kLSTMC1H1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册