未验证 提交 36ed83d2 编写于 作者: G GaoWei8 提交者: GitHub

Refine PADDLE_ENFORCE (#27360)

* refine PADDLE_ENFORCE
上级 effd51b6
......@@ -47,8 +47,8 @@ void OpTester::Init(const OpTesterConfig &config) {
CreateInputVarDesc();
CreateOutputVarDesc();
} else {
PADDLE_THROW(platform::errors::NotFound("Operator '%s' is not registered.",
config_.op_type));
PADDLE_THROW(platform::errors::NotFound(
"Operator '%s' is not registered in OpTester.", config_.op_type));
}
if (config_.device_id >= 0) {
......@@ -81,7 +81,8 @@ void OpTester::Run() {
platform::EnableProfiler(platform::ProfilerState::kAll);
platform::SetDeviceId(config_.device_id);
#else
PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
PADDLE_THROW(platform::errors::PermissionDenied(
"'CUDAPlace' is not supported in CPU only device."));
#endif
}
......@@ -162,7 +163,8 @@ framework::proto::VarType::Type OpTester::TransToVarType(std::string str) {
} else if (str == "fp64") {
return framework::proto::VarType::FP64;
} else {
PADDLE_THROW("Unsupported dtype %s.", str.c_str());
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported dtype %s in OpTester.", str.c_str()));
}
}
......@@ -233,8 +235,8 @@ void OpTester::CreateOpDesc() {
case framework::proto::AttrType::INTS:
case framework::proto::AttrType::FLOATS:
case framework::proto::AttrType::STRINGS:
PADDLE_THROW(
platform::errors::Unimplemented("Not supported STRINGS type yet."));
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported STRINGS type in OpTester yet."));
break;
case framework::proto::AttrType::LONG: {
int64_t value = StringTo<int64_t>(value_str);
......@@ -242,7 +244,8 @@ void OpTester::CreateOpDesc() {
} break;
case framework::proto::AttrType::LONGS:
default:
PADDLE_THROW("Unsupport attr type %d", type);
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport attr type %d in OpTester.", type));
}
}
}
......@@ -299,7 +302,8 @@ void OpTester::SetupTensor(framework::LoDTensor *tensor,
}
is.close();
} else {
PADDLE_THROW("Unsupported initializer %s.", initializer.c_str());
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported initializer %s in OpTester.", initializer.c_str()));
}
if (!platform::is_cpu_place(place_)) {
......@@ -351,7 +355,8 @@ void OpTester::CreateVariables(framework::Scope *scope) {
static_cast<double>(1.0), item.second.initializer,
item.second.filename);
} else {
PADDLE_THROW("Unsupported dtype %d.", data_type);
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported dtype %d in OpTester.", data_type));
}
VLOG(3) << "Set lod for tensor " << var_name;
......@@ -473,7 +478,8 @@ std::string OpTester::DebugString() {
<< "\n";
} break;
default:
PADDLE_THROW("Unsupport attr type %d", attr_type);
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport attr type %d in OpTester.", attr_type));
}
ss << GenSpaces(--count) << "}\n";
}
......@@ -484,8 +490,10 @@ std::string OpTester::DebugString() {
TEST(op_tester, base) {
if (!FLAGS_op_config_list.empty()) {
std::ifstream fin(FLAGS_op_config_list, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s",
FLAGS_op_config_list.c_str());
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin), true,
platform::errors::InvalidArgument("OpTester cannot open file %s",
FLAGS_op_config_list.c_str()));
std::vector<OpTesterConfig> op_configs;
while (!fin.eof()) {
VLOG(4) << "Reading config " << op_configs.size() << "...";
......
......@@ -78,7 +78,8 @@ void OpInputConfig::ParseDType(std::istream& is) {
} else if (dtype_str == "fp64" || dtype_str == "double") {
dtype = "fp64";
} else {
PADDLE_THROW("Unsupported dtype %s", dtype_str.c_str());
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported dtype %s in OpInputConfig.", dtype_str.c_str()));
}
VLOG(4) << "dtype of input " << name << " is: " << dtype;
}
......@@ -91,7 +92,9 @@ void OpInputConfig::ParseInitializer(std::istream& is) {
const std::vector<std::string> supported_initializers = {"random", "natural",
"zeros", "file"};
if (!Has(supported_initializers, initializer_str)) {
PADDLE_THROW("Unsupported initializer %s", initializer_str.c_str());
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported initializer %s in OpInputConfig.",
initializer_str.c_str()));
}
initializer = initializer_str;
......@@ -126,7 +129,12 @@ void OpInputConfig::ParseLoD(std::istream& is) {
}
}
EraseEndSep(&lod_str);
PADDLE_ENFORCE_GE(lod_str.length(), 4U);
PADDLE_ENFORCE_GE(
lod_str.length(), 4U,
platform::errors::InvalidArgument(
"The length of lod string should be "
"equal to or larger than 4. But length of lod string is %zu.",
lod_str.length()));
VLOG(4) << "lod: " << lod_str << ", length: " << lod_str.length();
// Parse the lod_str
......@@ -153,8 +161,10 @@ void OpInputConfig::ParseLoD(std::istream& is) {
OpTesterConfig::OpTesterConfig(const std::string& filename) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s",
filename.c_str());
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin), true,
platform::errors::InvalidArgument("OpTesterConfig cannot open file %s.",
filename.c_str()));
Init(fin);
}
......
......@@ -136,7 +136,6 @@ void BenchAllImpls(const typename KernelTuple::attr_type& attr, Args... args) {
}
using Tensor = paddle::framework::Tensor;
template <typename KernelTuple, typename PlaceType>
void BenchKernelXYZN() {
using T = typename KernelTuple::data_type;
......@@ -320,8 +319,15 @@ void BenchKernelSgd() {
const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
const int64_t upper) -> std::vector<int64_t> {
PADDLE_ENFORCE_LE(static_cast<size_t>(upper - lower), n - 1);
PADDLE_ENFORCE_GT(n, 0);
PADDLE_ENFORCE_LE(
static_cast<size_t>(upper - lower), n - 1,
paddle::platform::errors::InvalidArgument(
"The range of Sgd (upper - lower) should be equal to or lower "
"than n-1 (Sgd size -1). But upper - lower is %d and n-1 is %d.",
static_cast<size_t>(upper - lower), (n - 1)));
PADDLE_ENFORCE_GT(
n, 0, paddle::platform::errors::InvalidArgument(
"The Sgd size should be larger than 0. But the n is %d.", n));
std::vector<int64_t> all, out;
for (int i = 0; i < n; ++i) {
all.push_back(i);
......
......@@ -132,11 +132,31 @@ class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
}
std::unique_ptr<GenBase> CreateJitCode(
const emb_seq_pool_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.table_height, 0);
PADDLE_ENFORCE_GT(attr.table_width, 0);
PADDLE_ENFORCE_GT(attr.index_height, 0);
PADDLE_ENFORCE_GT(attr.index_width, 0);
PADDLE_ENFORCE_GT(attr.out_width, 0);
PADDLE_ENFORCE_GT(attr.table_height, 0,
platform::errors::InvalidArgument(
"The attribute table_height of EmbSeqPool should "
"be larger than 0. But it is %d.",
attr.table_height));
PADDLE_ENFORCE_GT(attr.table_width, 0,
platform::errors::InvalidArgument(
"The attribute table_width of EmbSeqPool should "
"be larger than 0. But it is %d.",
attr.table_width));
PADDLE_ENFORCE_GT(attr.index_height, 0,
platform::errors::InvalidArgument(
"The attribute index_height of EmbSeqPool should "
"be larger than 0. But it is %d.",
attr.index_height));
PADDLE_ENFORCE_GT(attr.index_width, 0,
platform::errors::InvalidArgument(
"The attribute index_width of EmbSeqPool should "
"be larger than 0. But it is %d.",
attr.index_width));
PADDLE_ENFORCE_GT(attr.out_width, 0,
platform::errors::InvalidArgument(
"The attribute out_width of EmbSeqPool should be "
"larger than 0. But it is %d.",
attr.out_width));
return make_unique<EmbSeqPoolJitCode>(attr, CodeSize(attr));
}
};
......
......@@ -29,7 +29,11 @@ void MatMulJitCode::genCode() {
preCode();
int block, rest;
const auto groups = packed_groups(n_, k_, &block, &rest);
PADDLE_ENFORCE_GT(groups.front(), 0);
PADDLE_ENFORCE_GT(
groups.front(), 0,
platform::errors::InvalidArgument("The number of rest registers should "
"be larger than 0. But it is %d.",
groups.front()));
const int block_len = sizeof(float) * block;
const int x_reg_idx = (block == ZMM_FLOAT_BLOCK ? 32 : 16) - 1;
......@@ -118,9 +122,21 @@ class MatMulCreator : public JitCodeCreator<matmul_attr_t> {
}
std::unique_ptr<GenBase> CreateJitCode(
const matmul_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.m, 0);
PADDLE_ENFORCE_GT(attr.n, 0);
PADDLE_ENFORCE_GT(attr.k, 0);
PADDLE_ENFORCE_GT(
attr.m, 0, platform::errors::InvalidArgument(
"The attribute m (first matrix's row) of MatMul should "
"be larger than 0. But it is %d.",
attr.m));
PADDLE_ENFORCE_GT(
attr.n, 0, platform::errors::InvalidArgument(
"The attribute n (first matrix's col) of MatMul should "
"be larger than 0. But it is %d.",
attr.n));
PADDLE_ENFORCE_GT(
attr.k, 0, platform::errors::InvalidArgument(
"The attribute k (second matrix's col) of MatMul should "
"be larger than 0. But it is %d.",
attr.k));
return make_unique<MatMulJitCode>(attr, CodeSize(attr));
}
};
......
......@@ -33,7 +33,10 @@ class MatMulJitCode : public JitCode {
size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), m_(attr.m), n_(attr.n), k_(attr.k) {
PADDLE_ENFORCE_EQ(m_, 1, "Only support m==1 yet");
PADDLE_ENFORCE_EQ(m_, 1, platform::errors::Unimplemented(
"Jitcode of matmul only support m==1 (first "
"matrix's row) now. But m is %d.",
m_));
this->genCode();
}
......
......@@ -70,8 +70,14 @@ class SeqPoolCreator : public JitCodeCreator<seq_pool_attr_t> {
}
std::unique_ptr<GenBase> CreateJitCode(
const seq_pool_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.w, 0);
PADDLE_ENFORCE_GT(attr.h, 0);
PADDLE_ENFORCE_GT(attr.w, 0, platform::errors::InvalidArgument(
"The attribute width of SeqPool should "
"be larger than 0. But it is %d.",
attr.w));
PADDLE_ENFORCE_GT(attr.h, 0, platform::errors::InvalidArgument(
"The attribute height of SeqPool should "
"be larger than 0. But it is %d.",
attr.h));
return make_unique<SeqPoolJitCode>(attr, CodeSize(attr));
}
};
......
......@@ -127,8 +127,13 @@ class SeqPoolJitCode : public JitCode {
vmovss(xmm_t(reg_idx + max_num_regs), ptr[reg_ptr_src_i]);
reg_idx++;
}
PADDLE_ENFORCE_EQ(reg_idx, rest_used_num_regs,
"All heights should use same regs");
PADDLE_ENFORCE_EQ(
reg_idx, rest_used_num_regs,
platform::errors::InvalidArgument(
"All heights of SeqPool should use the same number of registers."
"It equals to the numbr of rest registers. But use %d registers "
"and the numbr of rest registers is %d.",
reg_idx, rest_used_num_regs));
for (int i = 0; i < reg_idx; ++i) {
vaddps(xmm_t(i), xmm_t(i), xmm_t(i + max_num_regs));
}
......
......@@ -116,9 +116,24 @@ class SgdCreator : public JitCodeCreator<sgd_attr_t> {
size_t CodeSize(const sgd_attr_t& attr) const override { return 96 + 32 * 8; }
std::unique_ptr<GenBase> CreateJitCode(
const sgd_attr_t& attr) const override {
PADDLE_ENFORCE_EQ(attr.param_width, attr.grad_width);
PADDLE_ENFORCE_LE(attr.selected_rows_size, attr.grad_height);
PADDLE_ENFORCE_GE(attr.selected_rows_size, 0);
PADDLE_ENFORCE_EQ(attr.param_width, attr.grad_width,
platform::errors::InvalidArgument(
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d.",
attr.param_width, attr.grad_width));
PADDLE_ENFORCE_LE(attr.selected_rows_size, attr.grad_height,
platform::errors::InvalidArgument(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d.",
attr.selected_rows_size, attr.grad_height));
PADDLE_ENFORCE_GE(
attr.selected_rows_size, 0,
platform::errors::InvalidArgument(
"The attribute selected_rows_size of Sgd should be "
"equal to or larger than 0. But selected_rows_size is %d.",
attr.selected_rows_size));
return make_unique<SgdJitCode>(attr, CodeSize(attr));
}
};
......
......@@ -76,7 +76,11 @@ class VBroadcastCreator : public JitCodeCreator<int64_t> {
return 96 + (w / YMM_FLOAT_BLOCK) * 16 * 8;
}
std::unique_ptr<GenBase> CreateJitCode(const int64_t& w) const override {
PADDLE_ENFORCE_GT(w, 0);
PADDLE_ENFORCE_GT(
w, 0,
platform::errors::InvalidArgument(
"The width of VBroadcast should be larger than 0. But w is %d.",
w));
return make_unique<VBroadcastJitCode>(w, CodeSize(w));
}
};
......
......@@ -49,9 +49,14 @@ void GenBase::dumpCode(const unsigned char* code) const {
void* GenBase::operator new(size_t size) {
void* ptr;
constexpr size_t alignment = 32ul;
PADDLE_ENFORCE_EQ(posix_memalign(&ptr, alignment, size), 0,
"GenBase Alloc %ld error!", size);
PADDLE_ENFORCE(ptr, "Fail to allocate GenBase CPU memory: size = %d .", size);
PADDLE_ENFORCE_EQ(
posix_memalign(&ptr, alignment, size), 0,
platform::errors::InvalidArgument(
"Jitcode generator (GenBase) allocate %ld memory error!", size));
PADDLE_ENFORCE_NOT_NULL(ptr, platform::errors::InvalidArgument(
"Fail to allocate jitcode generator "
"(GenBase) CPU memory: size = %d .",
size));
return ptr;
}
......
......@@ -66,7 +66,8 @@ const char* to_string(KernelType kt) {
ONE_CASE(kEmbSeqPool);
ONE_CASE(kSgd);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
PADDLE_THROW(platform::errors::Unimplemented(
"JIT kernel do not support type: %d.", kt));
return "NOT JITKernel";
}
return nullptr;
......@@ -79,7 +80,8 @@ const char* to_string(SeqPoolType tp) {
ONE_CASE(kAvg);
ONE_CASE(kSqrt);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", tp);
PADDLE_THROW(platform::errors::Unimplemented(
"SeqPool JIT kernel do not support type: %d.", tp));
return "NOT PoolType";
}
return nullptr;
......@@ -100,7 +102,8 @@ KernelType to_kerneltype(const std::string& act) {
} else if (lower == "tanh" || lower == "vtanh") {
return kVTanh;
}
PADDLE_THROW("Not support type: %s, or forget to add this case", act);
PADDLE_THROW(platform::errors::Unimplemented(
"Act JIT kernel do not support type: %s.", act));
return kNone;
}
......@@ -109,12 +112,19 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
int block, rest;
const auto groups = packed_groups(n, k, &block, &rest);
std::for_each(groups.begin(), groups.end(), [&](int i) {
PADDLE_ENFORCE_GT(i, 0, "each element of groups should be larger than 0.");
PADDLE_ENFORCE_GT(i, 0, platform::errors::InvalidArgument(
"Each element of groups should be larger than "
"0. However the element: %d doesn't satify.",
i));
});
int sum = std::accumulate(groups.begin(), groups.end(), 0);
std::memset(dst, 0, k * sum * block * sizeof(float));
PADDLE_ENFORCE_GE(sum * block, n,
"The packed n should be equal to or larger than n");
platform::errors::InvalidArgument(
"The packed n (sum * block) should be equal to or "
"larger than n (matmul row size). "
"However, the packed n is %d and n is %d.",
sum * block, n));
const int block_len = sizeof(float) * block;
int n_offset = 0;
......@@ -136,7 +146,8 @@ void pack_weights<float>(const float* src, float* dst, int n, int k) {
template <typename T>
typename std::enable_if<!std::is_same<T, float>::value>::type pack_weights(
const T* src, T* dst, int n, int k) {
PADDLE_THROW("Only support pack with float type.");
PADDLE_THROW(platform::errors::Unimplemented(
"Only supports pack weights with float type."));
}
} // namespace jit
......
......@@ -85,8 +85,10 @@ inline const Kernel* GetReferKernel() {
auto& ref_pool = ReferKernelPool::Instance().AllKernels();
KernelKey kkey(KernelTuple::kernel_type, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
PADDLE_ENFORCE_NE(
ref_iter, ref_pool.end(),
platform::errors::PreconditionNotMet(
"Every Refer Kernel of jitcode should have reference function."));
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuple>*>(impl.get());
......@@ -101,7 +103,9 @@ template <typename KernelTuple>
inline typename KernelTuple::func_type GetReferFunc() {
auto ker = GetReferKernel<KernelTuple>();
auto p = dynamic_cast<const ReferKernel<KernelTuple>*>(ker);
PADDLE_ENFORCE(p, "The Refer kernel should exsit");
PADDLE_ENFORCE_NOT_NULL(p, platform::errors::InvalidArgument(
"Get the reference code of kernel in CPU "
"failed. The Refer kernel should exsit."));
return p->GetFunc();
}
......@@ -132,7 +136,9 @@ std::vector<const Kernel*> GetAllCandidateKernels(
// The last implementation should be reference function on CPUPlace.
auto ref = GetReferKernel<KernelTuple>();
PADDLE_ENFORCE(ref != nullptr, "Refer Kernel can not be empty.");
PADDLE_ENFORCE_NOT_NULL(ref, platform::errors::InvalidArgument(
"Get all candicate kernel in CPU failed. "
"The Refer Kernel can not be empty."));
res.emplace_back(ref);
return res;
}
......@@ -147,11 +153,14 @@ GetAllCandidateFuncsWithTypes(const typename KernelTuple::attr_type& attr) {
std::string name = k->ImplType();
if (name == "JitCode") {
auto i = dynamic_cast<const GenBase*>(k);
PADDLE_ENFORCE(i, "jitcode kernel cast can not fail.");
PADDLE_ENFORCE_NOT_NULL(i,
platform::errors::InvalidArgument(
"Generate jitcode kernel (GenBase) failed."));
res.emplace_back(std::make_pair(name, i->template getCode<Func>()));
} else {
auto i = dynamic_cast<const KernelMore<KernelTuple>*>(k);
PADDLE_ENFORCE(i, "kernel cast can not fail.");
PADDLE_ENFORCE_NOT_NULL(i, platform::errors::InvalidArgument(
"Kernel cast (KernelMore) failed."));
res.emplace_back(std::make_pair(name, i->GetFunc()));
}
}
......@@ -173,7 +182,9 @@ template <typename KernelTuple, typename PlaceType = platform::CPUPlace>
typename KernelTuple::func_type GetDefaultBestFunc(
const typename KernelTuple::attr_type& attr) {
auto funcs = GetAllCandidateFuncs<KernelTuple, PlaceType>(attr);
PADDLE_ENFORCE_GE(funcs.size(), 1UL);
PADDLE_ENFORCE_GE(funcs.size(), 1UL,
platform::errors::InvalidArgument(
"The candicate jit kernel is at least one in CPU."));
// Here could do some runtime benchmark of this attr and return the best one.
// But yet just get the first one as the default best one,
// which is searched in order and tuned by offline.
......
......@@ -95,7 +95,8 @@ void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
} else if (type == kVIdentity) {
return KernelFuncs<VIdentityTuple<T>, CPUPlace>::Cache().At(d);
}
PADDLE_THROW("Not support type: %s", type);
PADDLE_THROW(platform::errors::Unimplemented(
"Act JIT kernel do not support type: %s", type));
return nullptr;
}
......
......@@ -103,11 +103,24 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
template <typename T>
void EmbSeqPool(const T* table, const int64_t* idx, T* out,
const emb_seq_pool_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width);
PADDLE_ENFORCE_EQ(
attr->table_width * attr->index_width, attr->out_width,
platform::errors::InvalidArgument(
"The attribute table_width * index_width of EmbSeqPool should "
"be equal to out_width. But table_width * index_width is %d, "
"out_width is %d.",
attr->table_width * attr->index_width, attr->out_width));
auto check_idx_value_valid = [&](int64_t i) {
PADDLE_ENFORCE_LT(idx[i], attr->table_height, "idx value: %d, i: %d",
idx[i], i);
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
PADDLE_ENFORCE_LT(
idx[i], attr->table_height,
platform::errors::InvalidArgument(
"The idx shoud be lower than the attribute table_height of "
"EmbSeqPool. But %dth of idx is %d and table_height is %d.",
i, idx[i], attr->table_height));
PADDLE_ENFORCE_GE(idx[i], 0, platform::errors::InvalidArgument(
"The idx shoud be equal to or larger than "
"the 0. But %dth of idx is %d.",
i, idx[i]));
};
for (int64_t w = 0; w != attr->index_width; ++w) {
......@@ -168,22 +181,50 @@ void Softmax(const T* x, T* y, int n, int bs, int remain = 1) {
template <typename T>
void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
T* out, const sgd_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width);
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height);
PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width,
platform::errors::InvalidArgument(
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d.",
attr->param_width, attr->grad_width));
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height,
platform::errors::InvalidArgument(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d.",
attr->selected_rows_size, attr->grad_height));
T scalar = -lr[0];
int width = attr->grad_width;
if (out == param) {
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, attr->param_height);
PADDLE_ENFORCE_GE(h_idx, 0);
PADDLE_ENFORCE_LT(h_idx, attr->param_height,
platform::errors::InvalidArgument(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d.",
i, h_idx, attr->param_height));
PADDLE_ENFORCE_GE(h_idx, 0, platform::errors::InvalidArgument(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d.",
i, h_idx));
VAXPY(scalar, grad + i * width, out + h_idx * width, width);
}
} else {
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, attr->param_height);
PADDLE_ENFORCE_GE(h_idx, 0);
PADDLE_ENFORCE_LT(h_idx, attr->param_height,
platform::errors::InvalidArgument(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d.",
i, h_idx, attr->param_height));
PADDLE_ENFORCE_GE(h_idx, 0, platform::errors::InvalidArgument(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d.",
i, h_idx));
VScal(&scalar, grad + i * width, out + h_idx * width, width);
VAdd(param + h_idx * width, out + h_idx * width, out + h_idx * width,
width);
......
......@@ -147,7 +147,8 @@ void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
} else if (type == kVIdentity) {
return VIdentity<T>;
}
PADDLE_THROW("Not support type: %s", type);
PADDLE_THROW(platform::errors::Unimplemented(
"Act JIT kernel do not support type: %s.", type));
return nullptr;
}
......@@ -465,12 +466,25 @@ void Softmax(const T* x, T* y, int n, int bs = 1, int remain = 1) {
template <typename T>
void EmbSeqPool(const T* table, const int64_t* idx, T* out,
const emb_seq_pool_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width);
PADDLE_ENFORCE_EQ(
attr->table_width * attr->index_width, attr->out_width,
platform::errors::InvalidArgument(
"The attribute table_width * index_width of EmbSeqPool should "
"be equal to out_width. But table_width * index_width is %d and "
"out_width is %d.",
attr->table_width * attr->index_width, attr->out_width));
auto check_idx_value_valid = [&](int64_t i) {
PADDLE_ENFORCE_LT(idx[i], attr->table_height, "idx value: %d, i: %d",
idx[i], i);
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
PADDLE_ENFORCE_LT(
idx[i], attr->table_height,
platform::errors::InvalidArgument(
"The idx shoud be lower than the attribute table_height of "
"EmbSeqPool. But %dth of idx is %d and table_height is %d.",
i, idx[i], attr->table_height));
PADDLE_ENFORCE_GE(idx[i], 0, platform::errors::InvalidArgument(
"The idx shoud be equal to or larger than "
"the 0. But %dth of idx is %d.",
i, idx[i]));
};
for (int64_t w = 0; w != attr->index_width; ++w) {
......@@ -505,12 +519,31 @@ void EmbSeqPool(const T* table, const int64_t* idx, T* out,
template <typename T>
void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
T* out, const sgd_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width);
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height);
PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width,
platform::errors::InvalidArgument(
"The attribute param_width of Sgd should be "
"equal to the attribute grad_width. But param_width "
"is %d and grad_width is %d.",
attr->param_width, attr->grad_width));
PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height,
platform::errors::InvalidArgument(
"The attribute selected_rows_size of Sgd should be "
"equal to or less than the attribute grad_height. "
"But selected_rows_size is %d and grad_height is %d.",
attr->selected_rows_size, attr->grad_height));
for (int64_t i = 0; i < attr->selected_rows_size; ++i) {
auto h_idx = rows[i];
PADDLE_ENFORCE_LT(h_idx, attr->param_height);
PADDLE_ENFORCE_GE(h_idx, 0);
PADDLE_ENFORCE_LT(h_idx, attr->param_height,
platform::errors::InvalidArgument(
"The rows of Sgd should be "
"less than the attribute. But %dth of rows "
"is %d and grad_width is %d.",
i, h_idx, attr->param_height));
PADDLE_ENFORCE_GE(h_idx, 0, platform::errors::InvalidArgument(
"The rows of Sgd should be "
"larger than 0. But %dth of rows "
"is %d.",
i, h_idx));
for (int64_t j = 0; j < attr->grad_width; ++j) {
out[h_idx * attr->grad_width + j] =
param[h_idx * attr->grad_width + j] -
......
......@@ -850,8 +850,15 @@ void TestKernelSgd() {
const T lr = 0.1;
auto UnDuplicatedRandomVec = [](int n, const int64_t lower,
const int64_t upper) -> std::vector<int64_t> {
PADDLE_ENFORCE_LE(static_cast<size_t>(upper - lower), n - 1);
PADDLE_ENFORCE_GT(n, 0);
PADDLE_ENFORCE_LE(static_cast<size_t>(upper - lower), n - 1,
paddle::platform::errors::InvalidArgument(
"The range of Sgd (upper - lower) should be lower "
"than n-1 (Sgd size -1). But the upper - lower is %d "
"and n-1 is %d.",
static_cast<size_t>(upper - lower), n - 1));
PADDLE_ENFORCE_GT(
n, 0, paddle::platform::errors::InvalidArgument(
"The Sgd size should be larger than 0. But the n is %d.", n));
std::vector<int64_t> all, out;
for (int i = 0; i < n; ++i) {
all.push_back(i);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册