Commit d59f7335 authored by tensor-tang

refine softmax and use with cache

test=develop

Parent 7383eefd
```diff
@@ -187,6 +187,9 @@ void BenchAXYNKernel() {
     RandomVec<T>(d, x_data);
     BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), y_data,
                                                      d);
+    // test inplace
+    BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data<T>(), x_data,
+                                                     d);
   }
 }
```
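The added case benchmarks the AXYN kernels in place, passing `x_data` as the output buffer as well. For reference, a minimal sketch of what an AXYN-style kernel computes, assuming the `void(const T* a, const T* x, T* y, int n)` signature of `jit::AXYNTuples` (here scaling, as in `kVScal`); reading `x[i]` before writing `y[i]` is what makes the in-place call valid:

```cpp
// A minimal reference for an AXYN-style kernel (y = a[0] * x), assuming the
// jit::AXYNTuples signature void(const T*, const T*, T*, int).
// Passing y == x is safe: each x[i] is read before y[i] is written.
template <typename T>
void VScalRef(const T* a, const T* x, T* y, int n) {
  for (int i = 0; i < n; ++i) {
    y[i] = a[0] * x[i];
  }
}
```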
```diff
@@ -81,9 +81,7 @@ void VActJitCode::genCode() {
 #define DECLARE_ACT_CREATOR(name)                                            \
   class name##Creator : public JitCodeCreator<int> {                         \
    public:                                                                   \
-    bool UseMe(const int& attr) const override {                             \
-      return platform::MayIUse(platform::avx);                               \
-    }                                                                        \
+    bool UseMe(const int& attr) const override;                              \
     size_t CodeSize(const int& d) const override;                            \
     std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
       return make_unique<name##JitCode>(attr, CodeSize(attr));               \
```
```diff
@@ -98,6 +96,30 @@ DECLARE_ACT_CREATOR(VSigmoid);
 DECLARE_ACT_CREATOR(VTanh);
 
 // TODO(TJ): tuning use me
+bool VReluCreator::UseMe(const int& d) const {
+  return platform::MayIUse(platform::avx);
+}
+
+bool VSquareCreator::UseMe(const int& d) const {
+  return platform::MayIUse(platform::avx);
+}
+
+bool VIdentityCreator::UseMe(const int& d) const {
+  return platform::MayIUse(platform::avx);
+}
+
+bool VExpCreator::UseMe(const int& d) const {
+  return platform::MayIUse(platform::avx) && d < 32;
+}
+
+bool VSigmoidCreator::UseMe(const int& d) const {
+  return platform::MayIUse(platform::avx);
+}
+
+bool VTanhCreator::UseMe(const int& d) const {
+  return platform::MayIUse(platform::avx);
+}
+
 size_t VReluCreator::CodeSize(const int& d) const {
   return 96 /* init size */ +
          (d / YMM_FLOAT_BLOCK + 3) * 4 /* instructions */ *
```
```diff
@@ -118,6 +118,28 @@ typename KernelTuples::func_type Get(
   return GetRefer<KT, KernelTuples>();
 }
 
+template <KernelType KT, typename KernelTuples>
+class KernelFuncsCache {
+ public:
+  KernelFuncsCache() = default;
+  static KernelFuncsCache& Instance() {
+    static thread_local KernelFuncsCache<KT, KernelTuples> g_func_cache;
+    return g_func_cache;
+  }
+
+  bool Has(int key) const { return funcs_.find(key) != funcs_.end(); }
+
+  typename KernelTuples::func_type At(int key) { return funcs_.at(key); }
+
+  void Insert(int key, typename KernelTuples::func_type func) {
+    funcs_.emplace(key, func);
+  }
+
+ private:
+  std::unordered_map<int, typename KernelTuples::func_type> funcs_;
+
+  DISABLE_COPY_AND_ASSIGN(KernelFuncsCache);
+};
+
 const char* to_string(KernelType kt);
 const char* to_string(SeqPoolType kt);
```
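The cache is a `thread_local` singleton per (kernel type, tuple type), keyed by problem size, so kernel resolution is amortized to one hash lookup per call with no locking. The lookup-or-create pattern it supports, which the Softmax change below repeats verbatim for each sub-kernel, could be wrapped in a helper like this hypothetical `GetOrInsert` (illustration only, not part of the commit):

```cpp
// Hypothetical convenience wrapper around KernelFuncsCache (not in the diff).
template <KernelType KT, typename KernelTuples>
typename KernelTuples::func_type GetOrInsert(int n) {
  auto& cache = KernelFuncsCache<KT, KernelTuples>::Instance();
  if (!cache.Has(n)) {
    // First use of this size on this thread: resolve once and remember it.
    cache.Insert(n, Get<KT, KernelTuples, platform::CPUPlace>(n));
  }
  return cache.At(n);
}
```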
```diff
@@ -49,12 +49,50 @@ void VTanh(const T* x, T* y, int n) {
 }
 
 void Softmax(const T* x, T* y, int n, int bs) {
-  auto compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
-  auto compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
-  auto compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
-  auto compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
-  auto compute_vexp =
-      Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
+  typename XRNTuples<T>::func_type compute_hmax{nullptr};
+  typename XRNTuples<T>::func_type compute_hsum{nullptr};
+  typename AXYNTuples<T>::func_type compute_vscal{nullptr};
+  typename AXYNTuples<T>::func_type compute_vaddbias{nullptr};
+  typename XYNTuples<T>::func_type compute_vexp{nullptr};
+
+  if (!KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Has(n)) {
+    compute_hmax = Get<kHMax, XRNTuples<T>, platform::CPUPlace>(n);
+    KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().Insert(n, compute_hmax);
+  } else {
+    compute_hmax = KernelFuncsCache<kHMax, XRNTuples<T>>::Instance().At(n);
+  }
+
+  if (!KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Has(n)) {
+    compute_hsum = Get<kHSum, XRNTuples<T>, platform::CPUPlace>(n);
+    KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().Insert(n, compute_hsum);
+  } else {
+    compute_hsum = KernelFuncsCache<kHSum, XRNTuples<T>>::Instance().At(n);
+  }
+
+  if (!KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Has(n)) {
+    compute_vscal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
+    KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().Insert(n,
+                                                               compute_vscal);
+  } else {
+    compute_vscal = KernelFuncsCache<kVScal, AXYNTuples<T>>::Instance().At(n);
+  }
+
+  if (!KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Has(n)) {
+    compute_vaddbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
+    KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().Insert(
+        n, compute_vaddbias);
+  } else {
+    compute_vaddbias =
+        KernelFuncsCache<kVAddBias, AXYNTuples<T>>::Instance().At(n);
+  }
+
+  if (!KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Has(n)) {
+    compute_vexp = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
+    KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().Insert(n, compute_vexp);
+  } else {
+    compute_vexp = KernelFuncsCache<kVExp, XYNTuples<T>>::Instance().At(n);
+  }
+
   for (int i = 0; i < bs; ++i) {
     T scalar;
     compute_hmax(x, &scalar, n);
```
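The loop body that the hunk truncates composes these cached kernels into the usual max-shifted softmax per row. A sketch of that composition, assuming the tuple signatures used above (`XRNTuples` reduces x to a scalar, `AXYNTuples` combines a scalar with x into y, `XYNTuples` maps x elementwise into y); it is not the literal elided code:

```cpp
// Sketch of the per-row composition (illustration, not the elided lines):
for (int i = 0; i < bs; ++i) {
  T scalar;
  compute_hmax(x, &scalar, n);          // scalar = max(x)
  scalar = static_cast<T>(0) - scalar;  // negate so vaddbias subtracts it
  compute_vaddbias(&scalar, x, y, n);   // y = x - max(x)
  compute_vexp(y, y, n);                // y = exp(y), in place
  compute_hsum(y, &scalar, n);          // scalar = sum(y)
  scalar = static_cast<T>(1) / scalar;
  compute_vscal(&scalar, y, y, n);      // y = y / sum(y), in place
  x += n;                               // advance to the next row
  y += n;
}
```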
```diff
@@ -179,7 +179,8 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
 
 template <>
 bool SoftmaxKernel<float>::UseMe(const int& d) const {
-  return true;
+  // tuned on avx2
+  return platform::MayIUse(platform::avx) && d < 60;
 }
 
 #define AWALYS_USE_ME_WITH_DOUBLE(func) \
```
```diff
@@ -53,7 +53,7 @@ math_library(sequence2batch)
 math_library(sequence_padding)
 math_library(sequence_pooling DEPS math_function jit_kernel_helper)
 math_library(sequence_scale)
-math_library(softmax DEPS math_function)
+math_library(softmax DEPS math_function jit_kernel_helper)
 math_library(beam_search DEPS math_function)
 math_library(matrix_bit_code)
```
```diff
@@ -16,8 +16,8 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/jit/kernels.h"
 
 namespace paddle {
 namespace operators {
 namespace math {
```
```diff
@@ -81,28 +81,10 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
     const int kBatchDim = 0;
     const int kClassDim = 1;
     // 2D data. Batch x C
-    const int batch_size = in_dims[kBatchDim];
-    const int num_classes = in_dims[kClassDim];
-    std::vector<float> entities(batch_size);
-    auto blas = math::GetBlas<DeviceContext, float>(context);
-    for (int n = 0; n < batch_size; ++n) {
-      entities[n] = in_data[n * num_classes];
-      for (int c = 1; c < num_classes; ++c) {
-        entities[n] = in_data[n * num_classes + c] > entities[n]
-                          ? in_data[n * num_classes + c]
-                          : entities[n];
-      }
-      for (int c = 0; c < num_classes; ++c) {
-        out_data[n * num_classes + c] =
-            in_data[n * num_classes + c] - entities[n];
-      }
-    }
-    blas.VEXP(num_classes * batch_size, out_data, out_data);
-    for (int n = 0; n < batch_size; ++n) {
-      auto sum = blas.ASUM(num_classes, &out_data[n * num_classes], 1);
-      blas.SCAL(num_classes, 1.0f / sum, &out_data[n * num_classes]);
-    }
+    auto compute_softmax =
+        jit::Get<jit::kSoftmax, jit::SoftmaxTuples<float>, platform::CPUPlace>(
+            in_dims[kClassDim]);
+    compute_softmax(in_data, out_data, in_dims[kClassDim], in_dims[kBatchDim]);
   }
 };
```
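The removed hand-written BLAS loop and the jit kernel compute the same thing: a numerically stable row-wise softmax over `C = in_dims[kClassDim]` classes for each of `B = in_dims[kBatchDim]` rows. A plain, self-contained reference of that semantics (a sketch mirroring the deleted code, not the jit implementation):

```cpp
#include <algorithm>
#include <cmath>

// Reference semantics of compute_softmax(in, out, n, bs): for each of bs
// rows, shift by the row max, exponentiate, and normalize by the row sum.
void SoftmaxRef(const float* x, float* y, int n, int bs) {
  for (int b = 0; b < bs; ++b, x += n, y += n) {
    float row_max = x[0];
    for (int c = 1; c < n; ++c) row_max = std::max(row_max, x[c]);
    float sum = 0.f;
    for (int c = 0; c < n; ++c) {
      y[c] = std::exp(x[c] - row_max);  // max-shift keeps exp() in range
      sum += y[c];
    }
    for (int c = 0; c < n; ++c) y[c] /= sum;
  }
}
```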