提交 50945685 编写于 作者: T tensor-tang

add hmax, hsum jitcode

test=develop
上级 81177258
......@@ -28,3 +28,5 @@ USE_JITKERNEL_GEN(kGRUHtPart1)
USE_JITKERNEL_GEN(kGRUHtPart2)
USE_JITKERNEL_GEN(kNCHW16CMulNC)
USE_JITKERNEL_GEN(kSeqPool)
USE_JITKERNEL_GEN(kHMax)
USE_JITKERNEL_GEN(kHSum)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/hopv.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
// Emits machine code that reduces the num_ floats at param_src into one
// scalar written to param_dst. The reduction op (vmaxps or vaddps) is chosen
// by type_ via process(). Strategy: fold full 8-float YMM blocks first, then
// the 4/2/1-element tail, then reduce horizontally inside the register.
void HOPVJitCode::genCode() {
  const int num_blocks = num_ / YMM_FLOAT_BLOCK;
  int offset = 0;
  if (num_blocks > 0) {
    // load one firstly
    vmovups(ymm_tmp, ptr[param_src]);
    offset += sizeof(float) * YMM_FLOAT_BLOCK;
    // Accumulate the remaining full 8-float blocks into ymm_tmp.
    for (int i = 1; i < num_blocks; ++i) {
      vmovups(ymm_src, ptr[param_src + offset]);
      process(ymm_tmp, ymm_src, ymm_tmp);
      offset += sizeof(float) * YMM_FLOAT_BLOCK;
    }
    // Fold the upper 128 bits onto the lower 128 bits. Note xmm_tmp aliases
    // the low half of ymm_tmp (both are register index 0).
    vextractf128(xmm_dst, ymm_tmp, 1);
    process(xmm_dst, xmm_dst, xmm_tmp);
  } else {
    // Fewer than 8 elements: seed the accumulator with a value that is
    // neutral for the chosen op.
    if (type_ == operand_type::MAX) {
      // max: broadcasting the first element is safe (max is idempotent).
      vbroadcastss(ymm_dst, ptr[param_src]);
    } else if (type_ == operand_type::ADD) {
      // add: seed with zeros.
      vxorps(ymm_dst, ymm_dst, ymm_dst);
    }
  }
  int rest = num_ % YMM_FLOAT_BLOCK;
  if (rest >= 4) {
    // Fold 4 tail elements into the 4-lane accumulator.
    vmovups(xmm_src, ptr[param_src + offset]);
    offset += sizeof(float) * 4;
    rest -= 4;
    process(xmm_dst, xmm_dst, xmm_src);
  }
  // Reverse the 4 lanes (imm = 16 + 8 + 3 = 0b00011011) and combine: lanes
  // 0/1 now hold the pairwise results, reducing 4 live lanes to 2.
  vpermilps(xmm_tmp, xmm_dst, 16 + 8 + 3);
  process(xmm_dst, xmm_dst, xmm_tmp);
  if (rest >= 2) {
    // Fold 2 more tail elements into lanes 0/1 (vmovq zeroes lanes 2/3,
    // but those lanes are no longer read below).
    vmovq(xmm_src, ptr[param_src + offset]);
    offset += sizeof(float) * 2;
    rest -= 2;
    process(xmm_dst, xmm_dst, xmm_src);
  }
  // Bring lane 1 down to lane 0 (imm = 1) and combine: 2 live lanes -> 1.
  vpermilps(xmm_tmp, xmm_dst, 1);
  process(xmm_dst, xmm_dst, xmm_tmp);
  if (rest >= 1) {
    // Fold the final single tail element.
    vmovss(xmm_src, ptr[param_src + offset]);
    process(xmm_dst, xmm_dst, xmm_src);
  }
  // Store the scalar result from lane 0 and return.
  vmovss(ptr[param_dst], xmm_dst);
  ret();
}
// Declares a JitCodeCreator subclass (<name>Creator) for one horizontal
// reduction kernel:
//   - UseMe:         available whenever the CPU supports AVX.
//   - CodeSize:      rough upper bound on generated code bytes
//                    (96-byte fixed part plus per-8-float-block body).
//   - CreateJitCode: instantiates the matching <name>JitCode.
// NOTE: no comments inside the macro body — a '//' before a trailing '\'
// would swallow the line continuation.
#define DECLARE_HOP_CREATOR(name)                                            \
  class name##Creator : public JitCodeCreator<int> {                         \
   public:                                                                   \
    bool UseMe(const int& attr) const override {                             \
      return platform::MayIUse(platform::avx);                               \
    }                                                                        \
    size_t CodeSize(const int& d) const override {                           \
      return 96 + d / YMM_FLOAT_BLOCK * 4 * 8;                               \
    }                                                                        \
    std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
      return make_unique<name##JitCode>(attr, CodeSize(attr));               \
    }                                                                        \
  }

DECLARE_HOP_CREATOR(HMax);
DECLARE_HOP_CREATOR(HSum);

#undef DECLARE_HOP_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;

// Register the horizontal-max and horizontal-sum JIT code generators so the
// jit kernel registry can dispatch kHMax / kHSum to them.
REGISTER_JITKERNEL_GEN(kHMax, gen::HMaxCreator);
REGISTER_JITKERNEL_GEN(kHSum, gen::HSumCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
// horizontal operand vector
// Horizontal OPerand Vector: generates code that reduces a vector of d
// floats into a single scalar, using a horizontal max (kHMax) or horizontal
// sum (kHSum) depending on the operand type.
class HOPVJitCode : public JitCode {
 public:
  // d: number of input floats; type: MAX or ADD (anything else is fatal);
  // code_size / code_ptr are forwarded to JitCode's buffer management.
  explicit HOPVJitCode(int d, operand_type type, size_t code_size = 256 * 1024,
                       void* code_ptr = nullptr)
      : JitCode(code_size, code_ptr), num_(d), type_(type) {
    if (!(type_ == operand_type::MAX || type_ == operand_type::ADD)) {
      LOG(FATAL) << "Do not support this operand type: " << type_;
    }
    this->genCode();
  }

  virtual const char* name() const {
    // BUGFIX: the previous implementation built a local std::string and
    // returned base.c_str() — a dangling pointer as soon as name() returned.
    // Return string literals (static storage duration) with the same
    // content instead.
    // NOTE(review): the "VXXJitCode" prefix looks copy-pasted from vxx code;
    // kept as-is to preserve the observable name — confirm before renaming.
    return type_ == operand_type::MAX ? "VXXJitCode_MAX" : "VXXJitCode_SUM";
  }

  // Defined in hopv.cc: emits the reduction body.
  void genCode() override;

 protected:
  // Emits one reduction step: dst = op(src1, src2), where op is vmaxps for
  // MAX and vaddps for ADD. JMM may be an xmm_t or ymm_t register wrapper.
  template <typename JMM>
  void process(JMM& dst, JMM& src1, JMM& src2) {  // NOLINT
    if (type_ == operand_type::MAX) {
      vmaxps(dst, src1, src2);
    } else if (type_ == operand_type::ADD) {
      vaddps(dst, src1, src2);
    }
  }

 private:
  int num_;            // number of input elements
  operand_type type_;  // reduction kind: MAX or ADD

  reg64_t param_src{abi_param1};   // source float pointer
  reg64_t param_dst{abi_param2};   // destination float pointer
  reg64_t param_attr{abi_param3};  // declared but not read by genCode()

  // xmm_t(i) aliases the low 128 bits of ymm_t(i); genCode() relies on this
  // for tmp (index 0) when folding the upper YMM half.
  ymm_t ymm_tmp = ymm_t(0);
  ymm_t ymm_src = ymm_t(1);
  ymm_t ymm_dst = ymm_t(2);

  xmm_t xmm_tmp = xmm_t(0);
  xmm_t xmm_src = xmm_t(1);
  xmm_t xmm_dst = xmm_t(2);
};
// Declares a thin subclass (<name>JitCode) that pins HOPVJitCode's operand
// type: HMaxJitCode reduces with MAX, HSumJitCode with ADD.
#define DECLARE_HOP_JITCODE(name, op_type)                                    \
  class name##JitCode : public HOPVJitCode {                                  \
   public:                                                                    \
    explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
        : HOPVJitCode(d, op_type, code_size, code_ptr) {}                     \
  };

DECLARE_HOP_JITCODE(HMax, operand_type::MAX);
DECLARE_HOP_JITCODE(HSum, operand_type::ADD);

#undef DECLARE_HOP_JITCODE
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
......@@ -47,6 +47,7 @@ using Label = Xbyak::Label;
typedef enum {
MUL = 0,
MAX,
ADD,
SUB,
RELU,
......
......@@ -383,16 +383,19 @@ void TestAXYNKernel() {
// Tests X-Reduce-N kernels (input vector x, scalar result) such as kHMax /
// kHSum against the reference implementation for every size in TestSizes().
template <jit::KernelType KT, typename T, typename PlaceType>
void TestXRNKernel() {
  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
  // Horizontal reductions accumulate rounding error over the whole vector,
  // so temporarily relax the global accuracy threshold; restored below.
  auto last_acc = acc;
  acc = 1e-4;
  for (int d : TestSizes()) {
    auto ref = jit::GetRefer<KT, jit::XRNTuples<T>>();
    EXPECT_TRUE(ref != nullptr);
    std::vector<T> x(d);
    // BUGFIX: the merged diff text contained both the old unbounded
    // RandomVec call and this bounded one; keep only the bounded call.
    // Inputs in [-2, 2] keep the reduced result's magnitude small.
    RandomVec<T>(d, x.data(), -2.f, 2.f);
    T ref_res;
    ref(x.data(), &ref_res, d);
    TestAllImpls<KT, jit::XRNTuples<T>, PlaceType, std::vector<T>, T>(d, x,
                                                                     ref_res);
  }
  acc = last_acc;
}
template <jit::KernelType KT, typename T, typename PlaceType>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册