Unverified commit e5f9d3a4, authored by tensor-tang, committed by GitHub

Merge pull request #15892 from tensor-tang/jit/sgd

refine sgd op
@@ -332,6 +332,45 @@ void BenchEmbSeqPoolKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchSgdKernel() {
const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
const int64_t upper) -> std::vector<int64_t> {
PADDLE_ENFORCE_LE(static_cast<size_t>(upper - lower), n - 1);
PADDLE_ENFORCE_GT(n, 0);
std::vector<int64_t> all, out;
for (int i = 0; i < n; ++i) {
all.push_back(i);
}
std::random_shuffle(all.begin(), all.end());
out.insert(out.begin(), all.begin(), all.begin() + n);
return out;
};
for (int param_h : {1, 1000}) {
for (int grad_w : {1, 2, 8, 16, 30, 256}) {
// only benchmark inplace
Tensor param;
param.Resize({param_h, grad_w});
T* param_data = param.mutable_data<T>(PlaceType());
RandomVec<T>(param_h * grad_w, param_data, -2.f, 2.f);
for (int rows_size = 1; rows_size <= std::min(param_h, 10); ++rows_size) {
Tensor grad;
grad.Resize({rows_size, grad_w});
std::vector<int64_t> rows =
UnDuplicatedRandomVec(rows_size, 0, rows_size - 1);
RandomVec<T>(rows_size * grad_w, grad.mutable_data<T>(PlaceType()),
-2.f, 2.f);
const T* grad_data = grad.data<T>();
const int64_t* rows_data = rows.data();
jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size);
BenchAllImpls<KT, jit::SgdTuples<T>, PlaceType>(
attr, &lr, param_data, grad_data, rows_data, param_data, &attr);
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchMatMulKernel() {
for (int m : {1, 2, 3, 4}) {
@@ -477,6 +516,9 @@ BENCH_FP32_CPU(kEmbSeqPool) {
BenchEmbSeqPoolKernel<jit::kEmbSeqPool, T, CPUPlace>();
}
// sgd function
BENCH_FP32_CPU(kSgd) { BenchSgdKernel<jit::kSgd, T, CPUPlace>(); }
// matmul
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
...
@@ -32,3 +32,4 @@ USE_JITKERNEL_GEN(kSeqPool)
USE_JITKERNEL_GEN(kHMax)
USE_JITKERNEL_GEN(kHSum)
USE_JITKERNEL_GEN(kEmbSeqPool)
USE_JITKERNEL_GEN(kSgd)
@@ -31,7 +31,8 @@ namespace gen {
// Application Binary Interface
constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
abi_param4(Xbyak::Operand::RCX), abi_param5(Xbyak::Operand::R8),
abi_param6(Xbyak::Operand::R9);
constexpr Xbyak::Operand::Code g_abi_regs[] = {
Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
...
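The two new registers follow the System V AMD64 calling convention, which passes the first six integer or pointer arguments in RDI, RSI, RDX, RCX, R8 and R9. As a hedged sketch, not part of the commit, this is how the six arguments of the SGD kernel signature (from SgdTuples) land in them, matching the reg64_t members of SgdJitCode below:

// Sketch (illustrative): argument-to-register mapping for
//   void sgd(const T* lr, const T* param, const T* grad,
//            const int64_t* rows, T* out, const sgd_attr_t* attr);
// under the System V AMD64 ABI:
//   lr    -> RDI (abi_param1)
//   param -> RSI (abi_param2)
//   grad  -> RDX (abi_param3)  // mul clobbers RDX, so genCode copies it
//                              // to r10 first ("protect rdx")
//   rows  -> RCX (abi_param4)
//   out   -> R8  (abi_param5)
//   attr  -> R9  (abi_param6)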
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/sgd.h"
#include <stddef.h> // offsetof
#include <vector>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void SgdJitCode::genCode() {
preCode();
constexpr int block = YMM_FLOAT_BLOCK;
constexpr int max_num_regs = 7;
const int num_block = w_ / block;
const int num_groups = num_block / max_num_regs;
const size_t block_size = sizeof(float) * block;
const size_t width_size = w_ * sizeof(float);
std::vector<int> groups(num_groups, max_num_regs);
int rest_num_regs = num_block % max_num_regs;
if (rest_num_regs > 0) {
groups.push_back(rest_num_regs);
}
vbroadcastss(ymm_lr, ptr[param_lr]);
// protect rdx
mov(reg_ptr_grad_i, param_grad);
mov(reg_ptr_rows_i, param_rows);
mov(reg_rows_size_in_byte,
qword[param_attr + offsetof(sgd_attr_t, selected_rows_size)]);
mov(rax, sizeof(int64_t));
mul(reg_rows_size_in_byte);
mov(reg_rows_size_in_byte, rax);
add(reg_rows_size_in_byte, reg_ptr_rows_i);
Label l_next_row;
L(l_next_row);
{
mov(reg_row, qword[reg_ptr_rows_i]);
mov(rax, width_size);
mul(reg_row);
mov(reg_row, rax);
mov(reg_ptr_param_i, param_param);
mov(reg_ptr_out_i, param_out);
add(reg_ptr_param_i, reg_row);
add(reg_ptr_out_i, reg_row);
size_t w_offset = 0;
for (int num_regs : groups) {
// load grad
size_t inner_offset = w_offset;
for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
vmovups(ymm_t(reg_i), ptr[reg_ptr_grad_i + inner_offset]);
inner_offset += block_size;
}
// load param
inner_offset = w_offset;
for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
vmovups(ymm_t(reg_i + num_regs), ptr[reg_ptr_param_i + inner_offset]);
inner_offset += block_size;
}
// compute out = param - lr * grad
for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
vmulps(ymm_t(reg_i), ymm_t(reg_i), ymm_lr);
vsubps(ymm_t(reg_i + num_regs), ymm_t(reg_i + num_regs), ymm_t(reg_i));
}
// save out
inner_offset = w_offset;
for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
vmovups(ptr[reg_ptr_out_i + inner_offset], ymm_t(reg_i + num_regs));
inner_offset += block_size;
}
w_offset += (block_size * num_regs);
}
add(reg_ptr_grad_i, width_size);
add(reg_ptr_rows_i, sizeof(int64_t));
cmp(reg_ptr_rows_i, reg_rows_size_in_byte);
jl(l_next_row, T_NEAR);
}
postCode();
}
class SgdCreator : public JitCodeCreator<sgd_attr_t> {
public:
bool UseMe(const sgd_attr_t& attr) const override {
return platform::MayIUse(platform::avx) &&
attr.grad_width % YMM_FLOAT_BLOCK == 0;
}
size_t CodeSize(const sgd_attr_t& attr) const override {
return 96 + (attr.grad_width / YMM_FLOAT_BLOCK) * 32 * 8;
}
std::unique_ptr<GenBase> CreateJitCode(
const sgd_attr_t& attr) const override {
PADDLE_ENFORCE_EQ(attr.param_width, attr.grad_width);
PADDLE_ENFORCE_LE(attr.selected_rows_size, attr.grad_height);
PADDLE_ENFORCE_GE(attr.selected_rows_size, 0);
return make_unique<SgdJitCode>(attr, CodeSize(attr));
}
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kSgd, gen::SgdCreator);
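To make the grouping arithmetic in SgdJitCode::genCode concrete, here is a small standalone sketch, not part of the commit, with a hypothetical grad_width of 256. Each group needs num_regs ymm registers for grad plus num_regs more for param, while ymm15 keeps the broadcast learning rate, which is why max_num_regs is 7:

#include <cstdio>
#include <vector>

int main() {
  // Mirrors the grouping logic in SgdJitCode::genCode(), assuming AVX's
  // YMM_FLOAT_BLOCK of 8 floats and a hypothetical grad_width of 256.
  const int block = 8;
  const int max_num_regs = 7;  // 7 grad regs + 7 param regs; ymm15 holds lr
  const int w = 256;
  const int num_block = w / block;  // 32 blocks of 8 floats each
  std::vector<int> groups(num_block / max_num_regs, max_num_regs);
  const int rest = num_block % max_num_regs;
  if (rest > 0) groups.push_back(rest);
  for (int g : groups) std::printf("%d ", g);  // prints: 7 7 7 7 4
  return 0;
}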
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class SgdJitCode : public JitCode {
public:
explicit SgdJitCode(const sgd_attr_t& attr, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), w_(attr.grad_width) {
this->genCode();
}
DECLARE_JIT_CODE(SgdJitCode);
void genCode() override;
private:
int w_;
reg64_t param_lr{abi_param1};
reg64_t param_param{abi_param2};
reg64_t param_grad{abi_param3};
reg64_t param_rows{abi_param4};
reg64_t param_out{abi_param5};
reg64_t param_attr{abi_param6};
ymm_t ymm_lr = ymm_t(15);
reg64_t reg_ptr_grad_i{r10};
reg64_t reg_ptr_rows_i{r11};
reg64_t reg_rows_size_in_byte{r12};
reg64_t reg_row{r13};
reg64_t reg_ptr_param_i{r14};
reg64_t reg_ptr_out_i{r15};
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
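A minimal sketch, not part of the commit, of how a caller reaches the generated kernel through the usual jit::Get dispatch; it mirrors the dense branch of SGDOpKernel shown later in this diff, which treats the whole parameter as a single selected row:

// Sketch: dense SGD through jit::Get. param/grad are viewed as (1, sz)
// and a single row index 0 is selected.
#include "paddle/fluid/operators/jit/kernels.h"

template <typename T>
void DenseSgdSketch(const T* lr, const T* param, const T* grad, T* out,
                    int64_t sz) {
  namespace jit = paddle::operators::jit;
  jit::sgd_attr_t attr(1, sz, 1, sz, 1);
  int64_t rows_idx = 0;
  auto sgd = jit::Get<jit::kSgd, jit::SgdTuples<T>,
                      paddle::platform::CPUPlace>(attr);
  sgd(lr, param, grad, &rows_idx, out, &attr);
}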
@@ -55,6 +55,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kHSum);
ONE_CASE(kSoftmax);
ONE_CASE(kEmbSeqPool);
ONE_CASE(kSgd);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";
...
@@ -181,6 +181,14 @@ inline std::ostream& operator<<(std::ostream& os,
return os;
}
inline std::ostream& operator<<(std::ostream& os, const sgd_attr_t& attr) {
os << "param_height[" << attr.param_height << "],param_width["
<< attr.param_width << "],grad_height[" << attr.grad_height
<< "],grad_width[" << attr.grad_width << "],selected_rows_size["
<< attr.selected_rows_size << "]";
return os;
}
inline std::ostream& operator<<(std::ostream& os, const matmul_attr_t& attr) {
os << "M[" << attr.m << "],N[" << attr.n << "],K[" << attr.k << "]";
return os;
...
@@ -46,6 +46,7 @@ typedef enum {
kVMul,
kVRelu,
kVScal,
kSgd,
kVSigmoid,
kVSquare,
kVSub,
@@ -173,6 +174,28 @@ struct EmbSeqPoolTuples {
const emb_seq_pool_attr_t*);
};
typedef struct sgd_attr_s {
int64_t param_height, param_width;
int64_t grad_height, grad_width;
int64_t selected_rows_size;
sgd_attr_s() = default;
explicit sgd_attr_s(int64_t param_h, int64_t param_w, int64_t grad_h,
int64_t grad_w, int64_t selected_rows_sz)
: param_height(param_h),
param_width(param_w),
grad_height(grad_h),
grad_width(grad_w),
selected_rows_size(selected_rows_sz) {}
} sgd_attr_t;
template <typename T>
struct SgdTuples {
typedef T data_type;
typedef sgd_attr_t attr_type;
typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*,
const sgd_attr_t*);
};
typedef struct matmul_attr_s {
int m, n, k;
void* packed_weight{nullptr};
...
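For illustration, here are the two ways this diff fills sgd_attr_t in the operator further below, with hypothetical sizes; a dense grad is described as one row of length sz, while a sparse grad stores only the selected rows:

// Illustrative values only. Dense grad: one big row of length sz.
int64_t sz = 102 * 105;
jit::sgd_attr_t dense(/*param_h=*/1, /*param_w=*/sz,
                      /*grad_h=*/1, /*grad_w=*/sz,
                      /*selected_rows_sz=*/1);
// Sparse grad: grad stores only the selected rows, so grad_height is the
// number of selected rows (here 3), not the logical height (here 10).
jit::sgd_attr_t sparse(/*param_h=*/10, /*param_w=*/12,
                       /*grad_h=*/3, /*grad_w=*/12,
                       /*selected_rows_sz=*/3);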
@@ -13,6 +13,7 @@
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
@@ -23,14 +24,30 @@ size_t JitCodeKey<int>(const int& d) {
return d;
}
// TODO(TJ): refine and benchmark JitCodeKey generation
constexpr int act_type_shift = 3;  // support 2^3 act types
static inline int act_type_convert(KernelType type) {
if (type == kVIdentity) {
return 0;
} else if (type == kVExp) {
return 1;
} else if (type == kVRelu) {
return 2;
} else if (type == kVSigmoid) {
return 3;
} else if (type == kVTanh) {
return 4;
}
PADDLE_THROW("Unsupported act type %d", type);
return 0;
}
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
size_t key = attr.d;
int gate_key = act_type_convert(attr.act_gate) << 1;
int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift);
int cell_key = act_type_convert(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole;
}
@@ -38,8 +55,8 @@ size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
size_t key = attr.d;
return (key << (act_type_shift * 2)) + act_type_convert(attr.act_gate) +
(act_type_convert(attr.act_cand) << act_type_shift);
}
template <>
@@ -61,6 +78,11 @@ size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
return attr.table_width;
}
template <>
size_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
return attr.grad_width;
}
}  // namespace jit
}  // namespace operators
}  // namespace paddle
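So the new SGD key, like the EmbSeqPool one, depends only on the width: attributes that differ only in heights or row counts share the same generated code. A sketch with illustrative values:

// Illustrative: both attrs map to key 64 (their shared grad_width), so
// the cached jit code is reused across different heights and row counts.
jit::sgd_attr_t a(/*param_h=*/100, /*param_w=*/64, /*grad_h=*/5,
                  /*grad_w=*/64, /*selected_rows_sz=*/5);
jit::sgd_attr_t b(/*param_h=*/1000, /*param_w=*/64, /*grad_h=*/30,
                  /*grad_w=*/64, /*selected_rows_sz=*/30);
// jit::JitCodeKey<jit::sgd_attr_t>(a) == jit::JitCodeKey<jit::sgd_attr_t>(b)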
@@ -14,3 +14,4 @@ USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl)
USE_JITKERNEL_MORE(kSoftmax, mkl)
USE_JITKERNEL_MORE(kEmbSeqPool, mkl)
USE_JITKERNEL_MORE(kSgd, mkl)
@@ -184,6 +184,16 @@ bool EmbSeqPoolKernel<double>::UseMe(const emb_seq_pool_attr_t& attr) const {
return true;
}
template <>
bool SgdKernel<float>::UseMe(const sgd_attr_t& attr) const {
return true;
}
template <>
bool SgdKernel<double>::UseMe(const sgd_attr_t& attr) const {
return true;
}
template <>
bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const {
return platform::MayIUse(platform::avx);
@@ -239,5 +249,6 @@ REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_MKL_KERNEL(kSoftmax, Softmax);
REGISTER_MKL_KERNEL(kSgd, Sgd);
#undef REGISTER_MKL_KERNEL
@@ -142,6 +142,32 @@ void Softmax(const T* x, T* y, int n, int bs) {
}
}
template <typename T>
void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
T* out, const sgd_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width);
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height);
T scalar = -lr[0];
int width = attr->grad_width;
if (out == param) {
// inplace update: one fused axpy per selected row
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, attr->param_height);
PADDLE_ENFORCE_GE(h_idx, 0);
VAXPY(scalar, grad + i * width, out + h_idx * width, width);
}
} else {
// out != param: scale grad into out, then add param
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, attr->param_height);
PADDLE_ENFORCE_GE(h_idx, 0);
VScal(&scalar, grad + i * width, out + h_idx * width, width);
VAdd(param + h_idx * width, out + h_idx * width, out + h_idx * width,
width);
}
}
}
#define DECLARE_MKL_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \
@@ -173,6 +199,8 @@ DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples);
DECLARE_MKL_KERNEL(Sgd, SgdTuples);
#undef DECLARE_MKL_KERNEL
}  // namespace mkl
...
@@ -33,3 +33,4 @@ USE_JITKERNEL_REFER(kHSum)
USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool)
USE_JITKERNEL_REFER(kSgd)
@@ -59,4 +59,6 @@ REGISTER_REFER_KERNEL(kSoftmax, Softmax);
REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_REFER_KERNEL(kSgd, Sgd);
#undef REGISTER_REFER_KERNEL
@@ -446,6 +446,36 @@ void EmbSeqPool(const T* table, const int64_t* idx, T* out,
}
}
// SGD algorithm:
// lr is a pointer to the learning rate scalar
// param is an input matrix with shape (param_h, param_w)
// grad is an input matrix with shape (grad_h, grad_w), where grad_w == param_w
// selected_rows is a vector<int64_t> of size selected_rows_size (<= grad_h)
// out is an output matrix with shape (param_h, param_w)
//
// support both regular and sparse grad
// regular SGD: out[:] = param[:] - lr[0] * grad[:];
// sparse SGD: out[rows[i]][:] = param[rows[i]][:] - lr[0] * grad[i][:]
//
// Note: with sparse SGD, if out != param, the rows of out that are not
// selected are never written, so they may be left uninitialized
template <typename T>
void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
T* out, const sgd_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width);
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height);
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, attr->param_height);
PADDLE_ENFORCE_GE(h_idx, 0);
for (int64_t j = 0; j < attr->grad_width; ++j) {
out[h_idx * attr->grad_width + j] =
param[h_idx * attr->grad_width + j] -
lr[0] * grad[i * attr->grad_width + j];
}
}
}
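A small worked example of the semantics above, with hypothetical values and a hypothetical driver function; lr = 0.1, a 3x2 param filled with 5.0, and grad carrying the selected rows {0, 2} filled with 1.0, so rows 0 and 2 of out become 4.9 while row 1 is never written:

#include <cstdint>
#include <vector>

// Hypothetical driver, not part of the commit; assumes the headers and
// namespaces of this file are available.
void SgdReferExample() {
  float lr = 0.1f;
  std::vector<float> param(3 * 2, 5.0f);
  std::vector<float> out(3 * 2, 0.0f);   // out != param
  std::vector<float> grad(2 * 2, 1.0f);  // only the selected rows
  std::vector<int64_t> rows = {0, 2};
  paddle::operators::jit::sgd_attr_t attr(3, 2, 2, 2, 2);
  paddle::operators::jit::refer::Sgd<float>(
      &lr, param.data(), grad.data(), rows.data(), out.data(), &attr);
  // rows 0 and 2 of out are now 4.9 (= 5.0 - 0.1 * 1.0);
  // row 1 was never written and keeps its initial 0.0 values.
}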
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
@@ -496,6 +526,8 @@ DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples);
DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_REFER_KERNEL(Sgd, SgdTuples);
#undef DECLARE_REFER_KERNEL
}  // namespace refer
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <random>
#include <string>
#include <vector>
@@ -36,14 +37,14 @@ void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
}
template <typename T>
void ExpectEQ(const T* target, const T* refer, size_t n) {
if (std::is_floating_point<T>::value) {
for (size_t i = 0; i < n; ++i) {
EXPECT_NEAR(target[i], refer[i], FLAGS_acc) << " at index : " << i;
}
} else {
for (size_t i = 0; i < n; ++i) {
EXPECT_EQ(target[i], refer[i]) << " at index : " << i;
}
}
}
@@ -296,6 +297,45 @@ struct TestFuncWithRefer<jit::EmbSeqPoolTuples<T>, std::vector<T>,
}
};
template <typename T>
struct TestFuncWithRefer<jit::SgdTuples<T>, T, std::vector<T>, std::vector<T>,
std::vector<int64_t>, std::vector<T>,
typename jit::SgdTuples<T>::attr_type> {
void operator()(const typename jit::SgdTuples<T>::func_type tgt, const T lr,
const std::vector<T>& param, const std::vector<T>& grad,
const std::vector<int64_t>& rows, const std::vector<T>& oref,
const typename jit::SgdTuples<T>::attr_type& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(param.size(),
static_cast<size_t>(attr.param_height * attr.param_width));
EXPECT_EQ(grad.size(),
static_cast<size_t>(attr.grad_height * attr.grad_width));
EXPECT_EQ(rows.size(), static_cast<size_t>(attr.selected_rows_size));
EXPECT_EQ(param.size(), oref.size());
const T* param_data = param.data();
const T* grad_data = grad.data();
const int64_t* rows_data = rows.data();
const T* oref_data = oref.data();
std::vector<T> out(oref.size());
T* o_data = out.data();
tgt(&lr, param_data, grad_data, rows_data, o_data, &attr);
// only the selected rows should be equal
for (size_t i = 0; i < rows.size(); ++i) {
ExpectEQ<T>(o_data + rows[i] * attr.grad_width,
oref_data + rows[i] * attr.grad_width, attr.grad_width);
}
// inplace
std::copy(param.begin(), param.end(), out.begin());
tgt(&lr, o_data, grad_data, rows_data, o_data, &attr);
for (size_t i = 0; i < rows.size(); ++i) {
ExpectEQ<T>(o_data + rows[i] * attr.grad_width,
oref_data + rows[i] * attr.grad_width, attr.grad_width);
}
}
};
template <typename T>
struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>,
std::vector<T>, std::vector<T>,
@@ -407,7 +447,7 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelXYZNTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>();
@@ -440,7 +480,7 @@ void TestXYZNKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelAXYNTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::AXYNTuples<T>>();
@@ -466,7 +506,7 @@ void TestAXYNKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelXRNTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
auto last_acc = FLAGS_acc;
FLAGS_acc = 1e-4;
@@ -484,7 +524,7 @@ void TestXRNKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelXYNTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XYNTuples<T>>();
@@ -509,10 +549,12 @@ void TestXYNKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelLSTMTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int d : test_sizes) {
for (bool use_peephole : {true, false}) {
for (auto& act_gate : all_acts) {
for (auto& act_cand : all_acts) {
@@ -559,10 +601,12 @@ void TestLSTMKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelGRUTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int d : test_sizes) {
for (auto& act_gate : all_acts) {
for (auto& act_cand : all_acts) {
const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate),
@@ -593,14 +637,16 @@ void TestGRUKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelSeqPoolTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt};
auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (auto type : pool_types) {
for (int w : test_sizes) {
jit::seq_pool_attr_t attr(w, type);
for (int h : test_sizes) {
attr.h = h;
auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr);
@@ -618,11 +664,11 @@ void TestSeqPoolKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelMatMulTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
auto last_acc = FLAGS_acc;
// export MKL_CBWR=AVX would force MKL to use AVX
// export KMP_DETERMINISTIC_REDUCTION=yes would make the result deterministic
FLAGS_acc = 1e-3;
for (int m : {1, 2, 3, 4}) {
for (int n : {1, 2, 3, 4}) {
@@ -646,7 +692,7 @@ void TestMatMulKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelSoftmaxTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int bs : {1, 2, 10}) {
for (int n : TestSizes()) {
@@ -671,12 +717,14 @@ void TestSoftmaxKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelEmbSeqPoolTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
int64_t tbl_h = 1e4;
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum};  // only support sum yet
auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int tbl_w : test_sizes) {
std::vector<T> table(tbl_h * tbl_w);
RandomVec<T>(tbl_h * tbl_w, table.data(), -2.f, 2.f);
const T* table_data = table.data();
@@ -705,7 +753,61 @@ void TestEmbSeqPoolKernel() {
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelSgdTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
const int64_t upper) -> std::vector<int64_t> {
PADDLE_ENFORCE_LE(static_cast<size_t>(upper - lower), n - 1);
PADDLE_ENFORCE_GT(n, 0);
std::vector<int64_t> all, out;
for (int i = 0; i < n; ++i) {
all.push_back(i);
}
std::random_shuffle(all.begin(), all.end());
out.insert(out.begin(), all.begin(), all.begin() + n);
return out;
};
for (int param_h : {1, 10}) {
for (int grad_w : TestSizes()) {
std::vector<T> param(param_h * grad_w);
std::vector<T> param_out(param_h * grad_w);
RandomVec<T>(param_h * grad_w, param.data(), -2.f, 2.f);
const T* param_data = param.data();
T* out_data = param_out.data();
for (int rows_size = 1; rows_size <= param_h; ++rows_size) {
std::vector<T> grad(rows_size * grad_w);
std::vector<int64_t> rows =
UnDuplicatedRandomVec(rows_size, 0, rows_size - 1);
RandomVec<T>(rows_size * grad_w, grad.data(), -2.f, 2.f);
const int64_t* rows_data = rows.data();
const T* grad_data = grad.data();
auto ref = jit::GetRefer<KT, jit::SgdTuples<T>>();
EXPECT_TRUE(ref != nullptr);
jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size);
ref(&lr, param_data, grad_data, rows_data, out_data, &attr);
// inplace test
std::vector<T> inp(param.size());
std::copy(param.begin(), param.end(), inp.begin());
T* inp_data = inp.data();
ref(&lr, inp_data, grad_data, rows_data, inp_data, &attr);
// only the selected rows should be equal
for (int i = 0; i < rows_size; ++i) {
ExpectEQ<T>(inp_data + rows[i] * grad_w, out_data + rows[i] * grad_w,
grad_w);
}
TestAllImpls<KT, jit::SgdTuples<T>, PlaceType, T, std::vector<T>,
std::vector<T>, std::vector<int64_t>, std::vector<T>>(
attr, lr, param, grad, rows, param_out, attr);
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelNCHW16CMulNCTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
const int n = 3, c = 16 * 4, h = 10, w = 10;
auto ref = jit::GetRefer<KT, jit::NCHW16CMulNCTuples<T>>();
@@ -758,7 +860,7 @@ void TestNCHW16CMulNCKernel() {
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestKernelLayerNormTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
const T epsilon = 9.99999975e-06;
for (int n : {1, 2, 10}) {
@@ -797,11 +899,13 @@ void TestLayerNormKernel() {
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestKernelCRFDecodingTuples() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
constexpr int state_trans_base_idx = 2;
auto test_sizes = TestSizes();
test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
for (int seq_len : {1, 11, 17, 50}) {
for (int tag_num : test_sizes) {
auto ref = jit::GetRefer<KT, jit::CRFDecodingTuples<T>>();
EXPECT_TRUE(ref != nullptr);
int x_sz = seq_len * tag_num;
@@ -822,143 +926,76 @@ void TestCRFDecodingKernel() {
}
}
#define TEST_CPU_KERNEL(test_tuple, kernel_type) \
TEST(JITKernel, kernel_type) { \
TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
TestKernel##test_tuple<jit::kernel_type, double, CPUPlace>(); \
}

TEST_CPU_KERNEL(XYZNTuples, kVMul);
TEST_CPU_KERNEL(XYZNTuples, kVAdd);
TEST_CPU_KERNEL(XYZNTuples, kVAddRelu);
TEST_CPU_KERNEL(XYZNTuples, kVSub);

TEST_CPU_KERNEL(AXYNTuples, kVScal);
TEST_CPU_KERNEL(AXYNTuples, kVAddBias);

TEST_CPU_KERNEL(XRNTuples, kHMax);
TEST_CPU_KERNEL(XRNTuples, kHSum);

TEST_CPU_KERNEL(XYNTuples, kVRelu);
TEST_CPU_KERNEL(XYNTuples, kVIdentity);
TEST_CPU_KERNEL(XYNTuples, kVSquare);
TEST_CPU_KERNEL(XYNTuples, kVExp);
TEST_CPU_KERNEL(XYNTuples, kVSigmoid);
TEST_CPU_KERNEL(XYNTuples, kVTanh);

TEST_CPU_KERNEL(LSTMTuples, kLSTMCtHt);
TEST_CPU_KERNEL(LSTMTuples, kLSTMC1H1);

TEST_CPU_KERNEL(GRUTuples, kGRUH1);
TEST_CPU_KERNEL(GRUTuples, kGRUHtPart1);
TEST_CPU_KERNEL(GRUTuples, kGRUHtPart2);

TEST_CPU_KERNEL(NCHW16CMulNCTuples, kNCHW16CMulNC);

TEST_CPU_KERNEL(SeqPoolTuples, kSeqPool);
TEST_CPU_KERNEL(MatMulTuples, kMatMul);
TEST_CPU_KERNEL(SoftmaxTuples, kSoftmax);
TEST_CPU_KERNEL(EmbSeqPoolTuples, kEmbSeqPool);
TEST_CPU_KERNEL(SgdTuples, kSgd);
TEST_CPU_KERNEL(LayerNormTuples, kLayerNorm);
TEST_CPU_KERNEL(CRFDecodingTuples, kCRFDecoding);

TEST(JITKernel_key, lstm) {
jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
jit::lstm_attr_t attr2(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
jit::lstm_attr_t attr3(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
jit::lstm_attr_t attr4(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh);
auto key1 = jit::JitCodeKey<jit::lstm_attr_t>(attr1);
auto key2 = jit::JitCodeKey<jit::lstm_attr_t>(attr2);
auto key3 = jit::JitCodeKey<jit::lstm_attr_t>(attr3);
auto key4 = jit::JitCodeKey<jit::lstm_attr_t>(attr4);
EXPECT_TRUE(key1 != key2);
EXPECT_TRUE(key2 == key3);
EXPECT_TRUE(key3 != key4);
}

TEST(JITKernel_key, gru) {
jit::gru_attr_t attr1(8, jit::kVSigmoid, jit::kVTanh);
jit::gru_attr_t attr2(9, jit::kVSigmoid, jit::kVTanh);
jit::gru_attr_t attr3(9, jit::kVSigmoid, jit::kVTanh);
jit::gru_attr_t attr4(9, jit::kVSigmoid, jit::kVIdentity);
auto key1 = jit::JitCodeKey<jit::gru_attr_t>(attr1);
auto key2 = jit::JitCodeKey<jit::gru_attr_t>(attr2);
auto key3 = jit::JitCodeKey<jit::gru_attr_t>(attr3);
auto key4 = jit::JitCodeKey<jit::gru_attr_t>(attr4);
EXPECT_TRUE(key1 != key2);
EXPECT_TRUE(key2 == key3);
EXPECT_TRUE(key3 != key4);
}
// TODO(TJ): add more tests about key and pool
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/jit/kernels.h"
namespace paddle {
namespace operators {
@@ -32,53 +33,57 @@ class SGDOpKernel : public framework::OpKernel<T> {
if (param_var->IsType<framework::LoDTensor>()) {
const auto *param = ctx.Input<framework::Tensor>("Param");
auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
// Actually, all tensors are LoDTensor except SelectedRows.
if (grad_var->IsType<framework::LoDTensor>()) {
const auto *grad = ctx.Input<framework::Tensor>("Grad");
auto sz = param_out->numel();
PADDLE_ENFORCE_EQ(param->numel(), sz);
PADDLE_ENFORCE_EQ(grad->numel(), sz);
jit::sgd_attr_t attr(1, sz, 1, sz, 1);
const T *lr = learning_rate->data<T>();
const T *param_data = param->data<T>();
const T *grad_data = grad->data<T>();
int64_t rows_idx = 0;
T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
auto sgd =
jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr);
sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr);
} else if (grad_var->IsType<framework::SelectedRows>()) {
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
// This manual optimization brings difficulty to track data dependency.
// It's better to find a more elegant solution.
PADDLE_ENFORCE_EQ(param, param_out);
const auto *grad = ctx.Input<framework::SelectedRows>("Grad");
auto &grad_rows = grad->rows();
// for distributed training, a sparse var may be empty,
// just skip updating.
if (grad_rows.size() == 0) {
return;
}
auto out_dims = param_out->dims();
PADDLE_ENFORCE_EQ(grad->height(), out_dims[0]);
auto &grad_value = grad->value();
const T *param_data = param->data<T>();
const T *grad_data = grad_value.data<T>();
const T *lr = learning_rate->data<T>();
const int64_t *rows_data = grad_rows.data();
T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
jit::sgd_attr_t attr;
attr.param_height = out_dims[0];
attr.param_width = param_out->numel() / attr.param_height;
attr.grad_height = grad_rows.size();  // note: it is not grad->height()
attr.grad_width = grad_value.numel() / attr.grad_height;
attr.selected_rows_size = grad_rows.size();
PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width);
auto sgd =
jit::Get<jit::kSgd, jit::SgdTuples<T>, platform::CPUPlace>(attr);
sgd(lr, param_data, grad_data, rows_data, out_data, &attr);
} else {
PADDLE_THROW("Unsupported Variable Type of Grad");
}
...
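With the numbers from the Python test below (height 10, rows [0, 4, 7], row_numel 12), the sparse branch above would fill the attribute as sketched here; the point to notice is that grad_height is the number of materialized rows, not the SelectedRows height:

// Illustrative, using the Python test's numbers below:
// param is (10, 12), grad holds rows {0, 4, 7} only.
jit::sgd_attr_t attr;
attr.param_height = 10;   // out_dims[0]
attr.param_width = 12;    // param_out->numel() / param_height
attr.grad_height = 3;     // number of selected rows, not grad->height()
attr.grad_width = 12;     // grad_value.numel() / grad_height
attr.selected_rows_size = 3;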
@@ -24,17 +24,28 @@ from op_test import OpTest
class TestSGDOp(OpTest):
    def setUp(self):
        self.op_type = "sgd"
        self.conf()
        w = np.random.random((self.h, self.w)).astype("float32")
        g = np.random.random((self.h, self.w)).astype("float32")
        lr = np.array([0.1]).astype("float32")

        self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}
        self.outputs = {'ParamOut': w - lr * g}

    def conf(self):
        self.h = 102
        self.w = 105

    def test_check_output(self):
        self.check_output()


class TestSGDOpCase8X(TestSGDOp):
    def conf(self):
        self.h = 10
        self.w = 64


class TestSparseSGDOp(unittest.TestCase):
    def check_with_place(self, place):
        scope = core.Scope()
@@ -42,12 +53,12 @@ class TestSparseSGDOp(unittest.TestCase):
        # create and initialize Grad Variable
        height = 10
        rows = [0, 4, 7]
        self.conf()

        grad_selected_rows = scope.var('Grad').get_selected_rows()
        grad_selected_rows.set_height(height)
        grad_selected_rows.set_rows(rows)
        np_array = np.ones((len(rows), self.row_numel)).astype("float32")
        np_array[0, 0] = 2.0
        np_array[2, 8] = 4.0
@@ -56,7 +67,7 @@ class TestSparseSGDOp(unittest.TestCase):
        # create and initialize Param Variable
        param = scope.var('Param').get_tensor()
        param_array = np.full((height, self.row_numel), 5.0).astype("float32")
        param.set(param_array, place)

        # create and initialize LearningRate Variable
@@ -98,6 +109,14 @@ class TestSparseSGDOp(unittest.TestCase):
        for place in places:
            self.check_with_place(place)

    def conf(self):
        self.row_numel = 12


class TestSparseSGDOpCase8X(TestSparseSGDOp):
    def conf(self):
        self.row_numel = 16


class TestSGDOpOptimizeSelectedRows(unittest.TestCase):
    def check_with_place(self, place):
...