未验证 提交 36ed83d2 编写于 作者: G GaoWei8 提交者: GitHub

Refine PADDLE_ENFORCE (#27360)

* refine PADDLE_ENFORCE
上级 effd51b6
...@@ -47,8 +47,8 @@ void OpTester::Init(const OpTesterConfig &config) { ...@@ -47,8 +47,8 @@ void OpTester::Init(const OpTesterConfig &config) {
CreateInputVarDesc(); CreateInputVarDesc();
CreateOutputVarDesc(); CreateOutputVarDesc();
} else { } else {
PADDLE_THROW(platform::errors::NotFound("Operator '%s' is not registered.", PADDLE_THROW(platform::errors::NotFound(
config_.op_type)); "Operator '%s' is not registered in OpTester.", config_.op_type));
} }
if (config_.device_id >= 0) { if (config_.device_id >= 0) {
...@@ -81,7 +81,8 @@ void OpTester::Run() { ...@@ -81,7 +81,8 @@ void OpTester::Run() {
platform::EnableProfiler(platform::ProfilerState::kAll); platform::EnableProfiler(platform::ProfilerState::kAll);
platform::SetDeviceId(config_.device_id); platform::SetDeviceId(config_.device_id);
#else #else
PADDLE_THROW("'CUDAPlace' is not supported in CPU only device."); PADDLE_THROW(platform::errors::PermissionDenied(
"'CUDAPlace' is not supported in CPU only device."));
#endif #endif
} }
...@@ -162,7 +163,8 @@ framework::proto::VarType::Type OpTester::TransToVarType(std::string str) { ...@@ -162,7 +163,8 @@ framework::proto::VarType::Type OpTester::TransToVarType(std::string str) {
} else if (str == "fp64") { } else if (str == "fp64") {
return framework::proto::VarType::FP64; return framework::proto::VarType::FP64;
} else { } else {
PADDLE_THROW("Unsupported dtype %s.", str.c_str()); PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported dtype %s in OpTester.", str.c_str()));
} }
} }
...@@ -233,8 +235,8 @@ void OpTester::CreateOpDesc() { ...@@ -233,8 +235,8 @@ void OpTester::CreateOpDesc() {
case framework::proto::AttrType::INTS: case framework::proto::AttrType::INTS:
case framework::proto::AttrType::FLOATS: case framework::proto::AttrType::FLOATS:
case framework::proto::AttrType::STRINGS: case framework::proto::AttrType::STRINGS:
PADDLE_THROW( PADDLE_THROW(platform::errors::Unimplemented(
platform::errors::Unimplemented("Not supported STRINGS type yet.")); "Unsupported STRINGS type in OpTester yet."));
break; break;
case framework::proto::AttrType::LONG: { case framework::proto::AttrType::LONG: {
int64_t value = StringTo<int64_t>(value_str); int64_t value = StringTo<int64_t>(value_str);
...@@ -242,7 +244,8 @@ void OpTester::CreateOpDesc() { ...@@ -242,7 +244,8 @@ void OpTester::CreateOpDesc() {
} break; } break;
case framework::proto::AttrType::LONGS: case framework::proto::AttrType::LONGS:
default: default:
PADDLE_THROW("Unsupport attr type %d", type); PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport attr type %d in OpTester.", type));
} }
} }
} }
...@@ -299,7 +302,8 @@ void OpTester::SetupTensor(framework::LoDTensor *tensor, ...@@ -299,7 +302,8 @@ void OpTester::SetupTensor(framework::LoDTensor *tensor,
} }
is.close(); is.close();
} else { } else {
PADDLE_THROW("Unsupported initializer %s.", initializer.c_str()); PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported initializer %s in OpTester.", initializer.c_str()));
} }
if (!platform::is_cpu_place(place_)) { if (!platform::is_cpu_place(place_)) {
...@@ -351,7 +355,8 @@ void OpTester::CreateVariables(framework::Scope *scope) { ...@@ -351,7 +355,8 @@ void OpTester::CreateVariables(framework::Scope *scope) {
static_cast<double>(1.0), item.second.initializer, static_cast<double>(1.0), item.second.initializer,
item.second.filename); item.second.filename);
} else { } else {
PADDLE_THROW("Unsupported dtype %d.", data_type); PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported dtype %d in OpTester.", data_type));
} }
VLOG(3) << "Set lod for tensor " << var_name; VLOG(3) << "Set lod for tensor " << var_name;
...@@ -473,7 +478,8 @@ std::string OpTester::DebugString() { ...@@ -473,7 +478,8 @@ std::string OpTester::DebugString() {
<< "\n"; << "\n";
} break; } break;
default: default:
PADDLE_THROW("Unsupport attr type %d", attr_type); PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport attr type %d in OpTester.", attr_type));
} }
ss << GenSpaces(--count) << "}\n"; ss << GenSpaces(--count) << "}\n";
} }
...@@ -484,8 +490,10 @@ std::string OpTester::DebugString() { ...@@ -484,8 +490,10 @@ std::string OpTester::DebugString() {
TEST(op_tester, base) { TEST(op_tester, base) {
if (!FLAGS_op_config_list.empty()) { if (!FLAGS_op_config_list.empty()) {
std::ifstream fin(FLAGS_op_config_list, std::ios::in | std::ios::binary); std::ifstream fin(FLAGS_op_config_list, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", PADDLE_ENFORCE_EQ(
FLAGS_op_config_list.c_str()); static_cast<bool>(fin), true,
platform::errors::InvalidArgument("OpTester cannot open file %s",
FLAGS_op_config_list.c_str()));
std::vector<OpTesterConfig> op_configs; std::vector<OpTesterConfig> op_configs;
while (!fin.eof()) { while (!fin.eof()) {
VLOG(4) << "Reading config " << op_configs.size() << "..."; VLOG(4) << "Reading config " << op_configs.size() << "...";
......
...@@ -78,7 +78,8 @@ void OpInputConfig::ParseDType(std::istream& is) { ...@@ -78,7 +78,8 @@ void OpInputConfig::ParseDType(std::istream& is) {
} else if (dtype_str == "fp64" || dtype_str == "double") { } else if (dtype_str == "fp64" || dtype_str == "double") {
dtype = "fp64"; dtype = "fp64";
} else { } else {
PADDLE_THROW("Unsupported dtype %s", dtype_str.c_str()); PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported dtype %s in OpInputConfig.", dtype_str.c_str()));
} }
VLOG(4) << "dtype of input " << name << " is: " << dtype; VLOG(4) << "dtype of input " << name << " is: " << dtype;
} }
...@@ -91,7 +92,9 @@ void OpInputConfig::ParseInitializer(std::istream& is) { ...@@ -91,7 +92,9 @@ void OpInputConfig::ParseInitializer(std::istream& is) {
const std::vector<std::string> supported_initializers = {"random", "natural", const std::vector<std::string> supported_initializers = {"random", "natural",
"zeros", "file"}; "zeros", "file"};
if (!Has(supported_initializers, initializer_str)) { if (!Has(supported_initializers, initializer_str)) {
PADDLE_THROW("Unsupported initializer %s", initializer_str.c_str()); PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported initializer %s in OpInputConfig.",
initializer_str.c_str()));
} }
initializer = initializer_str; initializer = initializer_str;
...@@ -126,7 +129,12 @@ void OpInputConfig::ParseLoD(std::istream& is) { ...@@ -126,7 +129,12 @@ void OpInputConfig::ParseLoD(std::istream& is) {
} }
} }
EraseEndSep(&lod_str); EraseEndSep(&lod_str);
PADDLE_ENFORCE_GE(lod_str.length(), 4U); PADDLE_ENFORCE_GE(
lod_str.length(), 4U,
platform::errors::InvalidArgument(
"The length of lod string should be "
"equal to or larger than 4. But length of lod string is %zu.",
lod_str.length()));
VLOG(4) << "lod: " << lod_str << ", length: " << lod_str.length(); VLOG(4) << "lod: " << lod_str << ", length: " << lod_str.length();
// Parse the lod_str // Parse the lod_str
...@@ -153,8 +161,10 @@ void OpInputConfig::ParseLoD(std::istream& is) { ...@@ -153,8 +161,10 @@ void OpInputConfig::ParseLoD(std::istream& is) {
OpTesterConfig::OpTesterConfig(const std::string& filename) { OpTesterConfig::OpTesterConfig(const std::string& filename) {
std::ifstream fin(filename, std::ios::in | std::ios::binary); std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s", PADDLE_ENFORCE_EQ(
filename.c_str()); static_cast<bool>(fin), true,
platform::errors::InvalidArgument("OpTesterConfig cannot open file %s.",
filename.c_str()));
Init(fin); Init(fin);
} }
......
...@@ -136,7 +136,6 @@ void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) { ...@@ -136,7 +136,6 @@ void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
} }
using Tensor = paddle::framework::Tensor; using Tensor = paddle::framework::Tensor;
template <typename KernelTuple, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void BenchKernelXYZN() { void BenchKernelXYZN() {
using T = typename KernelTuple::data_type; using T = typename KernelTuple::data_type;
...@@ -320,8 +319,15 @@ void BenchKernelSgd() { ...@@ -320,8 +319,15 @@ void BenchKernelSgd() {
const T lr = 0.1; const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower, auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
const int64_t upper) -> std::vector<int64_t> { const int64_t upper) -> std::vector<int64_t> {
PADDLE_ENFORCE_LE(static_cast<size_t>(upper - lower), n - 1); PADDLE_ENFORCE_LE(
PADDLE_ENFORCE_GT(n, 0); static_cast<size_t>(upper - lower), n - 1,
paddle::platform::errors::InvalidArgument(
"The range of Sgd (upper - lower) should be equal to or lower "
"than n-1 (Sgd size -1). But upper - lower is %d and n-1 is %d.",
static_cast<size_t>(upper - lower), (n - 1)));
PADDLE_ENFORCE_GT(
n, 0, paddle::platform::errors::InvalidArgument(
"The Sgd size should be larger than 0. But the n is %d.", n));
std::vector<int64_t> all, out; std::vector<int64_t> all, out;
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
all.push_back(i); all.push_back(i);
......
...@@ -132,11 +132,31 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> { ...@@ -132,11 +132,31 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
} }
std::unique_ptr<GenBase> CreateJitCode( std::unique_ptr<GenBase> CreateJitCode(
const emb_seq_pool_attr_t& attr) const override { const emb_seq_pool_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.table_height, 0); PADDLE_ENFORCE_GT(attr.table_height, 0,
PADDLE_ENFORCE_GT(attr.table_width, 0); platform::errors::InvalidArgument(
PADDLE_ENFORCE_GT(attr.index_height, 0); "The attribute table_height of EmbSeqPool should "
PADDLE_ENFORCE_GT(attr.index_width, 0); "be larger than 0. But it is %d.",
PADDLE_ENFORCE_GT(attr.out_width, 0); attr.table_height));
PADDLE_ENFORCE_GT(attr.table_width, 0,
platform::errors::InvalidArgument(
"The attribute table_width of EmbSeqPool should "
"be larger than 0. But it is %d.",
attr.table_width));
PADDLE_ENFORCE_GT(attr.index_height, 0,
platform::errors::InvalidArgument(
"The attribute index_height of EmbSeqPool should "
"be larger than 0. But it is %d.",
attr.index_height));
PADDLE_ENFORCE_GT(attr.index_width, 0,
platform::errors::InvalidArgument(
"The attribute index_width of EmbSeqPool should "
"be larger than 0. But it is %d.",
attr.index_width));
PADDLE_ENFORCE_GT(attr.out_width, 0,
platform::errors::InvalidArgument(
"The attribute out_width of EmbSeqPool should be "
"larger than 0. But it is %d.",
attr.out_width));
return make_unique<EmbSeqPoolJitCode>(attr, CodeSize(attr)); return make_unique<EmbSeqPoolJitCode>(attr, CodeSize(attr));
} }
}; };
......
...@@ -29,7 +29,11 @@ void MatMulJitCode::genCode() { ...@@ -29,7 +29,11 @@ void MatMulJitCode::genCode() {
preCode(); preCode();
int block, rest; int block, rest;
const auto groups = packed_groups(n_, k_, &block, &rest); const auto groups = packed_groups(n_, k_, &block, &rest);
PADDLE_ENFORCE_GT(groups.front(), 0); PADDLE_ENFORCE_GT(
groups.front(), 0,
platform::errors::InvalidArgument("The number of rest registers should "
"be larger than 0. But it is %d.",
groups.front()));
const int block_len = sizeof(float) * block; const int block_len = sizeof(float) * block;
const int x_reg_idx = (block == ZMM_FLOAT_BLOCK ? 32 : 16) - 1; const int x_reg_idx = (block == ZMM_FLOAT_BLOCK ? 32 : 16) - 1;
...@@ -118,9 +122,21 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> { ...@@ -118,9 +122,21 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
} }
std::unique_ptr<GenBase> CreateJitCode( std::unique_ptr<GenBase> CreateJitCode(
const matmul_attr_t& attr) const override { const matmul_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.m, 0); PADDLE_ENFORCE_GT(
PADDLE_ENFORCE_GT(attr.n, 0); attr.m, 0, platform::errors::InvalidArgument(
PADDLE_ENFORCE_GT(attr.k, 0); "The attribute m (first matrix's row) of MatMul should "
"be larger than 0. But it is %d.",
attr.m));
PADDLE_ENFORCE_GT(
attr.n, 0, platform::errors::InvalidArgument(
"The attribute n (first matrix's col) of MatMul should "
"be larger than 0. But it is %d.",
attr.n));
PADDLE_ENFORCE_GT(
attr.k, 0, platform::errors::InvalidArgument(
"The attribute k (second matrix's col) of MatMul should "
"be larger than 0. But it is %d.",
attr.k));
return make_unique<MatMulJitCode>(attr, CodeSize(attr)); return make_unique<MatMulJitCode>(attr, CodeSize(attr));
} }
}; };
......
...@@ -33,7 +33,10 @@ class MatMulJitCode : public JitCode { ...@@ -33,7 +33,10 @@ class MatMulJitCode : public JitCode {
size_t code_size = 256 * 1024, size_t code_size = 256 * 1024,
void* code_ptr = nullptr) void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), m_(attr.m), n_(attr.n), k_(attr.k) { : JitCode(code_size, code_ptr), m_(attr.m), n_(attr.n), k_(attr.k) {
PADDLE_ENFORCE_EQ(m_, 1, "Only support m==1 yet"); PADDLE_ENFORCE_EQ(m_, 1, platform::errors::Unimplemented(
"Jitcode of matmul only support m==1 (first "
"matrix's row) now. But m is %d.",
m_));
this->genCode(); this->genCode();
} }
......
...@@ -70,8 +70,14 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> { ...@@ -70,8 +70,14 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
} }
std::unique_ptr<GenBase> CreateJitCode( std::unique_ptr<GenBase> CreateJitCode(
const seq_pool_attr_t& attr) const override { const seq_pool_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.w, 0); PADDLE_ENFORCE_GT(attr.w, 0, platform::errors::InvalidArgument(
PADDLE_ENFORCE_GT(attr.h, 0); "The attribute width of SeqPool should "
"be larger than 0. But it is %d.",
attr.w));
PADDLE_ENFORCE_GT(attr.h, 0, platform::errors::InvalidArgument(
"The attribute height of SeqPool should "
"be larger than 0. But it is %d.",
attr.h));
return make_unique<SeqPoolJitCode>(attr, CodeSize(attr)); return make_unique<SeqPoolJitCode>(attr, CodeSize(attr));
} }
}; };
......
...@@ -127,8 +127,13 @@ class SeqPoolJitCode : public JitCode { ...@@ -127,8 +127,13 @@ class SeqPoolJitCode : public JitCode {
vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]); vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
reg_idx++; reg_idx++;
} }
PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs, PADDLE_ENFORCE_EQ(
"All heights should use same regs"); reg_idx, rest_used_num_regs,
platform::errors::InvalidArgument(
"All heights of SeqPool should use the same number of registers."
"It equals to the numbr of rest registers. But use %d registers "
"and the numbr of rest registers is %d.",
reg_idx, rest_used_num_regs));
for (int i = 0; i < reg_idx; ++i) { for (int i = 0; i < reg_idx; ++i) {
vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs)); vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
} }
......
...@@ -116,9 +116,24 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> { ...@@ -116,9 +116,24 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
size_t CodeSize(const sgd_attr_t& attr) const override { return 96 + 32 * 8; } size_t CodeSize(const sgd_attr_t& attr) const override { return 96 + 32 * 8; }
std::unique_ptr<GenBase> CreateJitCode( std::unique_ptr<GenBase> CreateJitCode(
const sgd_attr_t& attr) const override { const sgd_attr_t& attr) const override {
PADDLE_ENFORCE_EQ(attr.param_width, attr.grad_width); PADDLE_ENFORCE_EQ(attr.param_width, attr.grad_width,
PADDLE_ENFORCE_LE(attr.selected_rows_size, attr.grad_height); platform::errors::InvalidArgument(
PADDLE_ENFORCE_GE(attr.selected_rows_size, 0); "The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d.",
attr.param_width, attr.grad_width));
PADDLE_ENFORCE_LE(attr.selected_rows_size, attr.grad_height,
platform::errors::InvalidArgument(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d.",
attr.selected_rows_size, attr.grad_height));
PADDLE_ENFORCE_GE(
attr.selected_rows_size, 0,
platform::errors::InvalidArgument(
"The attribute selected_rows_size of Sgd should be "
"equal to or larger than 0. But selected_rows_size is %d.",
attr.selected_rows_size));
return make_unique<SgdJitCode>(attr, CodeSize(attr)); return make_unique<SgdJitCode>(attr, CodeSize(attr));
} }
}; };
......
...@@ -76,7 +76,11 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> { ...@@ -76,7 +76,11 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> {
return 96 + (w / YMM_FLOAT_BLOCK) * 16 * 8; return 96 + (w / YMM_FLOAT_BLOCK) * 16 * 8;
} }
std::unique_ptr<GenBase> CreateJitCode(const int64_t& w) const override { std::unique_ptr<GenBase> CreateJitCode(const int64_t& w) const override {
PADDLE_ENFORCE_GT(w, 0); PADDLE_ENFORCE_GT(
w, 0,
platform::errors::InvalidArgument(
"The width of VBroadcast should be larger than 0. But w is %d.",
w));
return make_unique<VBroadcastJitCode>(w, CodeSize(w)); return make_unique<VBroadcastJitCode>(w, CodeSize(w));
} }
}; };
......
...@@ -49,9 +49,14 @@ void GenBase::dumpCode(const unsigned char* code) const { ...@@ -49,9 +49,14 @@ void GenBase::dumpCode(const unsigned char* code) const {
void* GenBase::operator new(size_t size) { void* GenBase::operator new(size_t size) {
void* ptr; void* ptr;
constexpr size_t alignment = 32ul; constexpr size_t alignment = 32ul;
PADDLE_ENFORCE_EQ(posix_memalign(&ptr, alignment, size), 0, PADDLE_ENFORCE_EQ(
"GenBase Alloc %ld error!", size); posix_memalign(&ptr, alignment, size), 0,
PADDLE_ENFORCE(ptr, "Fail to allocate GenBase CPU memory: size = %d .", size); platform::errors::InvalidArgument(
"Jitcode generator (GenBase) allocate %ld memory error!", size));
PADDLE_ENFORCE_NOT_NULL(ptr, platform::errors::InvalidArgument(
"Fail to allocate jitcode generator "
"(GenBase) CPU memory: size = %d .",
size));
return ptr; return ptr;
} }
......
...@@ -66,7 +66,8 @@ const char* to_string(KernelType kt) { ...@@ -66,7 +66,8 @@ const char* to_string(KernelType kt) {
ONE_CASE(kEmbSeqPool); ONE_CASE(kEmbSeqPool);
ONE_CASE(kSgd); ONE_CASE(kSgd);
default: default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt); PADDLE_THROW(platform::errors::Unimplemented(
"JIT kernel do not support type: %d.", kt));
return "NOT JITKernel"; return "NOT JITKernel";
} }
return nullptr; return nullptr;
...@@ -79,7 +80,8 @@ const char* to_string(SeqPoolType tp) { ...@@ -79,7 +80,8 @@ const char* to_string(SeqPoolType tp) {
ONE_CASE(kAvg); ONE_CASE(kAvg);
ONE_CASE(kSqrt); ONE_CASE(kSqrt);
default: default:
PADDLE_THROW("Not support type: %d, or forget to add it.", tp); PADDLE_THROW(platform::errors::Unimplemented(
"SeqPool JIT kernel do not support type: %d.", tp));
return "NOT PoolType"; return "NOT PoolType";
} }
return nullptr; return nullptr;
...@@ -100,7 +102,8 @@ KernelType to_kerneltype(const std::string& act) { ...@@ -100,7 +102,8 @@ KernelType to_kerneltype(const std::string& act) {
} else if (lower == "tanh" || lower == "vtanh") { } else if (lower == "tanh" || lower == "vtanh") {
return kVTanh; return kVTanh;
} }
PADDLE_THROW("Not support type: %s, or forget to add this case", act); PADDLE_THROW(platform::errors::Unimplemented(
"Act JIT kernel do not support type: %s.", act));
return kNone; return kNone;
} }
...@@ -109,12 +112,19 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) { ...@@ -109,12 +112,19 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
int block, rest; int block, rest;
const auto groups = packed_groups(n, k, &block, &rest); const auto groups = packed_groups(n, k, &block, &rest);
std::for_each(groups.begin(), groups.end(), [&](int i) { std::for_each(groups.begin(), groups.end(), [&](int i) {
PADDLE_ENFORCE_GT(i, 0, "each element of groups should be larger than 0."); PADDLE_ENFORCE_GT(i, 0, platform::errors::InvalidArgument(
"Each element of groups should be larger than "
"0. However the element: %d doesn't satify.",
i));
}); });
int sum = std::accumulate(groups.begin(), groups.end(), 0); int sum = std::accumulate(groups.begin(), groups.end(), 0);
std::memset(dst, 0, k * sum * block * sizeof(float)); std::memset(dst, 0, k * sum * block * sizeof(float));
PADDLE_ENFORCE_GE(sum * block, n, PADDLE_ENFORCE_GE(sum * block, n,
"The packed n should be equal to or larger than n"); platform::errors::InvalidArgument(
"The packed n (sum * block) should be equal to or "
"larger than n (matmul row size). "
"However, the packed n is %d and n is %d.",
sum * block, n));
const int block_len = sizeof(float) * block; const int block_len = sizeof(float) * block;
int n_offset = 0; int n_offset = 0;
...@@ -136,7 +146,8 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) { ...@@ -136,7 +146,8 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
template <typename T> template <typename T>
typename std::enable_if<!std::is_same<T, float>::value>::type pack_weights( typename std::enable_if<!std::is_same<T, float>::value>::type pack_weights(
const T* src, T* dst, int n, int k) { const T* src, T* dst, int n, int k) {
PADDLE_THROW("Only support pack with float type."); PADDLE_THROW(platform::errors::Unimplemented(
"Only supports pack weights with float type."));
} }
} // namespace jit } // namespace jit
......
...@@ -85,8 +85,10 @@ inline const Kernel* GetReferKernel() { ...@@ -85,8 +85,10 @@ inline const Kernel* GetReferKernel() {
auto& ref_pool = ReferKernelPool::Instance().AllKernels(); auto& ref_pool = ReferKernelPool::Instance().AllKernels();
KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace()); KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey); auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(), PADDLE_ENFORCE_NE(
"Every Kernel should have reference function."); ref_iter, ref_pool.end(),
platform::errors::PreconditionNotMet(
"Every Refer Kernel of jitcode should have reference function."));
auto& ref_impls = ref_iter->second; auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) { for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get()); auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
...@@ -101,7 +103,9 @@ template <typename KernelTuple> ...@@ -101,7 +103,9 @@ template <typename KernelTuple>
inline typename KernelTuple::func_type GetReferFunc() { inline typename KernelTuple::func_type GetReferFunc() {
auto ker = GetReferKernel<KernelTuple>(); auto ker = GetReferKernel<KernelTuple>();
auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker); auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
PADDLE_ENFORCE(p, "The Refer kernel should exsit"); PADDLE_ENFORCE_NOT_NULL(p, platform::errors::InvalidArgument(
"Get the reference code of kernel in CPU "
"failed. The Refer kernel should exsit."));
return p->GetFunc(); return p->GetFunc();
} }
...@@ -132,7 +136,9 @@ std::vector<const Kernel*> GetAllCandidateKernels( ...@@ -132,7 +136,9 @@ std::vector<const Kernel*> GetAllCandidateKernels(
// The last implementation should be reference function on CPUPlace. // The last implementation should be reference function on CPUPlace.
auto ref = GetReferKernel<KernelTuple>(); auto ref = GetReferKernel<KernelTuple>();
PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty."); PADDLE_ENFORCE_NOT_NULL(ref, platform::errors::InvalidArgument(
"Get all candicate kernel in CPU failed. "
"The Refer Kernel can not be empty."));
res.emplace_back(ref); res.emplace_back(ref);
return res; return res;
} }
...@@ -147,11 +153,14 @@ GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) { ...@@ -147,11 +153,14 @@ GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
std::string name = k->ImplType(); std::string name = k->ImplType();
if (name == "JitCode") { if (name == "JitCode") {
auto i = dynamic_cast<const GenBase*>(k); auto i = dynamic_cast<const GenBase*>(k);
PADDLE_ENFORCE(i, "jitcode kernel cast can not fail."); PADDLE_ENFORCE_NOT_NULL(i,
platform::errors::InvalidArgument(
"Generate jitcode kernel (GenBase) failed."));
res.emplace_back(std::make_pair(name, i->template getCode<Func>())); res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
} else { } else {
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k); auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
PADDLE_ENFORCE(i, "kernel cast can not fail."); PADDLE_ENFORCE_NOT_NULL(i, platform::errors::InvalidArgument(
"Kernel cast (KernelMore) failed."));
res.emplace_back(std::make_pair(name, i->GetFunc())); res.emplace_back(std::make_pair(name, i->GetFunc()));
} }
} }
...@@ -173,7 +182,9 @@ template <typename KernelTuple, typename PlaceType = platform::CPUPlace> ...@@ -173,7 +182,9 @@ template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename KernelTuple::func_type GetDefaultBestFunc( typename KernelTuple::func_type GetDefaultBestFunc(
const typename KernelTuple::attr_type& attr) { const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr); auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
PADDLE_ENFORCE_GE(funcs.size(), 1UL); PADDLE_ENFORCE_GE(funcs.size(), 1UL,
platform::errors::InvalidArgument(
"The candicate jit kernel is at least one in CPU."));
// Here could do some runtime benchmark of this attr and return the best one. // Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one, // But yet just get the first one as the default best one,
// which is searched in order and tuned by offline. // which is searched in order and tuned by offline.
......
...@@ -95,7 +95,8 @@ void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT ...@@ -95,7 +95,8 @@ void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
} else if (type == kVIdentity) { } else if (type == kVIdentity) {
return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d); return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d);
} }
PADDLE_THROW("Not support type: %s", type); PADDLE_THROW(platform::errors::Unimplemented(
"Act JIT kernel do not support type: %s", type));
return nullptr; return nullptr;
} }
......
...@@ -103,11 +103,24 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { ...@@ -103,11 +103,24 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
template <typename T> template <typename T>
void EmbSeqPool(const T* table, const int64_t* idx, T* out, void EmbSeqPool(const T* table, const int64_t* idx, T* out,
const emb_seq_pool_attr_t* attr) { const emb_seq_pool_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width); PADDLE_ENFORCE_EQ(
attr->table_width * attr->index_width, attr->out_width,
platform::errors::InvalidArgument(
"The attribute table_width * index_width of EmbSeqPool should "
"be equal to out_width. But table_width * index_width is %d, "
"out_width is %d.",
attr->table_width * attr->index_width, attr->out_width));
auto check_idx_value_valid = [&](int64_t i) { auto check_idx_value_valid = [&](int64_t i) {
PADDLE_ENFORCE_LT(idx[i], attr->table_height, "idx value: %d, i: %d", PADDLE_ENFORCE_LT(
idx[i], i); idx[i], attr->table_height,
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i); platform::errors::InvalidArgument(
"The idx shoud be lower than the attribute table_height of "
"EmbSeqPool. But %dth of idx is %d and table_height is %d.",
i, idx[i], attr->table_height));
PADDLE_ENFORCE_GE(idx[i], 0, platform::errors::InvalidArgument(
"The idx shoud be equal to or larger than "
"the 0. But %dth of idx is %d.",
i, idx[i]));
}; };
for (int64_t w = 0; w != attr->index_width; ++w) { for (int64_t w = 0; w != attr->index_width; ++w) {
...@@ -168,22 +181,50 @@ void Softmax(const T* x, T* y, int n, int bs, int remain = 1) { ...@@ -168,22 +181,50 @@ void Softmax(const T* x, T* y, int n, int bs, int remain = 1) {
template <typename T> template <typename T>
void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows, void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
T* out, const sgd_attr_t* attr) { T* out, const sgd_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width); PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width,
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height); platform::errors::InvalidArgument(
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d.",
attr->param_width, attr->grad_width));
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height,
platform::errors::InvalidArgument(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d.",
attr->selected_rows_size, attr->grad_height));
T scalar = -lr[0]; T scalar = -lr[0];
int width = attr->grad_width; int width = attr->grad_width;
if (out == param) { if (out == param) {
for (int64_t i = 0; i < attr->selected_rows_size; ++i) { for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
auto h_idx = rows[i]; auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, attr->param_height); PADDLE_ENFORCE_LT(h_idx, attr->param_height,
PADDLE_ENFORCE_GE(h_idx, 0); platform::errors::InvalidArgument(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d.",
i, h_idx, attr->param_height));
PADDLE_ENFORCE_GE(h_idx, 0, platform::errors::InvalidArgument(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d.",
i, h_idx));
VAXPY(scalar, grad + i * width, out + h_idx * width, width); VAXPY(scalar, grad + i * width, out + h_idx * width, width);
} }
} else { } else {
for (int64_t i = 0; i < attr->selected_rows_size; ++i) { for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
auto h_idx = rows[i]; auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, attr->param_height); PADDLE_ENFORCE_LT(h_idx, attr->param_height,
PADDLE_ENFORCE_GE(h_idx, 0); platform::errors::InvalidArgument(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d.",
i, h_idx, attr->param_height));
PADDLE_ENFORCE_GE(h_idx, 0, platform::errors::InvalidArgument(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d.",
i, h_idx));
VScal(&scalar, grad + i * width, out + h_idx * width, width); VScal(&scalar, grad + i * width, out + h_idx * width, width);
VAdd(param + h_idx * width, out + h_idx * width, out + h_idx * width, VAdd(param + h_idx * width, out + h_idx * width, out + h_idx * width,
width); width);
......
...@@ -147,7 +147,8 @@ void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT ...@@ -147,7 +147,8 @@ void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
} else if (type == kVIdentity) { } else if (type == kVIdentity) {
return VIdentity<T>; return VIdentity<T>;
} }
PADDLE_THROW("Not support type: %s", type); PADDLE_THROW(platform::errors::Unimplemented(
"Act JIT kernel do not support type: %s.", type));
return nullptr; return nullptr;
} }
...@@ -465,12 +466,25 @@ void Softmax(const T* x, T* y, int n, int bs = 1, int remain = 1) { ...@@ -465,12 +466,25 @@ void Softmax(const T* x, T* y, int n, int bs = 1, int remain = 1) {
template <typename T> template <typename T>
void EmbSeqPool(const T* table, const int64_t* idx, T* out, void EmbSeqPool(const T* table, const int64_t* idx, T* out,
const emb_seq_pool_attr_t* attr) { const emb_seq_pool_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width); PADDLE_ENFORCE_EQ(
attr->table_width * attr->index_width, attr->out_width,
platform::errors::InvalidArgument(
"The attribute table_width * index_width of EmbSeqPool should "
"be equal to out_width. But table_width * index_width is %d and "
"out_width is %d.",
attr->table_width * attr->index_width, attr->out_width));
auto check_idx_value_valid = [&](int64_t i) { auto check_idx_value_valid = [&](int64_t i) {
PADDLE_ENFORCE_LT(idx[i], attr->table_height, "idx value: %d, i: %d", PADDLE_ENFORCE_LT(
idx[i], i); idx[i], attr->table_height,
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i); platform::errors::InvalidArgument(
"The idx shoud be lower than the attribute table_height of "
"EmbSeqPool. But %dth of idx is %d and table_height is %d.",
i, idx[i], attr->table_height));
PADDLE_ENFORCE_GE(idx[i], 0, platform::errors::InvalidArgument(
"The idx shoud be equal to or larger than "
"the 0. But %dth of idx is %d.",
i, idx[i]));
}; };
for (int64_t w = 0; w != attr->index_width; ++w) { for (int64_t w = 0; w != attr->index_width; ++w) {
...@@ -505,12 +519,31 @@ void EmbSeqPool(const T* table, const int64_t* idx, T* out, ...@@ -505,12 +519,31 @@ void EmbSeqPool(const T* table, const int64_t* idx, T* out,
template <typename T> template <typename T>
void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows, void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
T* out, const sgd_attr_t* attr) { T* out, const sgd_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width); PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width,
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height); platform::errors::InvalidArgument(
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d.",
attr->param_width, attr->grad_width));
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height,
platform::errors::InvalidArgument(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d.",
attr->selected_rows_size, attr->grad_height));
for (int64_t i = 0; i < attr->selected_rows_size; ++i) { for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
auto h_idx = rows[i]; auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, attr->param_height); PADDLE_ENFORCE_LT(h_idx, attr->param_height,
PADDLE_ENFORCE_GE(h_idx, 0); platform::errors::InvalidArgument(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d.",
i, h_idx, attr->param_height));
PADDLE_ENFORCE_GE(h_idx, 0, platform::errors::InvalidArgument(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d.",
i, h_idx));
for (int64_t j = 0; j < attr->grad_width; ++j) { for (int64_t j = 0; j < attr->grad_width; ++j) {
out[h_idx * attr->grad_width + j] = out[h_idx * attr->grad_width + j] =
param[h_idx * attr->grad_width + j] - param[h_idx * attr->grad_width + j] -
......
...@@ -850,8 +850,15 @@ void TestKernelSgd() { ...@@ -850,8 +850,15 @@ void TestKernelSgd() {
const T lr = 0.1; const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower, auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
const int64_t upper) -> std::vector<int64_t> { const int64_t upper) -> std::vector<int64_t> {
PADDLE_ENFORCE_LE(static_cast<size_t>(upper - lower), n - 1); PADDLE_ENFORCE_LE(static_cast<size_t>(upper - lower), n - 1,
PADDLE_ENFORCE_GT(n, 0); paddle::platform::errors::InvalidArgument(
"The range of Sgd (upper - lower) should be lower "
"than n-1 (Sgd size -1). But the upper - lower is %d "
"and n-1 is %d.",
static_cast<size_t>(upper - lower), n - 1));
PADDLE_ENFORCE_GT(
n, 0, paddle::platform::errors::InvalidArgument(
"The Sgd size should be larger than 0. But the n is %d.", n));
std::vector<int64_t> all, out; std::vector<int64_t> all, out;
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
all.push_back(i); all.push_back(i);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册