提交 f5c2d175 编写于 作者: Y Yang Yu

Refine

上级 6f5e64af
......@@ -59,15 +59,16 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
static void CheckTensorNANOrInf(const std::string& name,
const framework::Tensor& tensor) {
if (tensor.type().hash_code() != typeid(float).hash_code() &&
tensor.type().hash_code() != typeid(double).hash_code()) {
if (tensor.memory_size() == 0) {
return;
}
if (tensor.memory_size() == 0) {
if (tensor.type().hash_code() != typeid(float).hash_code() &&
tensor.type().hash_code() != typeid(double).hash_code()) {
return;
}
PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name);
PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name);
PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN, %p", name,
&tensor);
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
......
......@@ -134,8 +134,17 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
#endif
offset_ = 0;
}
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
void* buf = reinterpret_cast<void*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
if (type.hash_code() == typeid(float).hash_code() ||
type.hash_code() == typeid(double).hash_code()) {
float* tmp = (float*)(buf);
for (int64_t i = 0; i < numel(); ++i) {
tmp[i] = NAN;
}
}
return buf;
}
inline void* Tensor::mutable_data(platform::Place place) {
......
......@@ -35,6 +35,7 @@ class Variable {
template <typename T>
T* GetMutable() {
if (!IsType<T>()) {
VLOG(10) << "Resetting " << *this->name_;
holder_.reset(new PlaceholderImpl<T>(new T()));
}
return static_cast<T*>(holder_->Ptr());
......
......@@ -51,6 +51,7 @@ class FillConstantOp : public framework::OperatorBase {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(dev_place);
VLOG(10) << "FillConstant to " << &out;
math::set_constant(dev_ctx, &out, value);
}
};
......
......@@ -116,9 +116,10 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
auto height = dout_tensor.dims()[0];
auto slice = dx_tensor.Slice(0, static_cast<int>(height));
framework::CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx, &slice);
if (dx_tensor.dims()[0] < height) {
VLOG(10) << dx_tensor.dims()[0] << ", " << height;
if (dx_tensor.dims()[0] > height) {
auto rest_tensor = dx_tensor.Slice(
static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
static_cast<int>(height), static_cast<int>(dx_tensor.dims()[0]));
math::set_constant(dev_ctx, &rest_tensor, 0.0f);
}
}
......
......@@ -38,11 +38,9 @@ class SumKernel : public framework::OpKernel<T> {
if (out_var->IsType<framework::LoDTensor>()) {
auto *out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto result = EigenVector<T>::Flatten(*out);
if (!in_place) {
out->mutable_data<T>(context.GetPlace());
math::SetConstant<DeviceContext, T> constant_functor;
constant_functor(context.template device_context<DeviceContext>(), out,
0.0);
......
......@@ -130,9 +130,9 @@ class ReadFromArrayOp : public ArrayOp {
auto &x_array = x->Get<framework::LoDTensorArray>();
auto *out = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out != nullptr, "Out must be set");
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
size_t offset = GetOffset(scope, place);
if (offset < x_array.size()) {
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
......
......@@ -194,14 +194,27 @@ class WhileGradOp : public framework::OperatorBase {
}
}
auto check_var_no_nan = [](const framework::Scope &scope,
const std::string &var_name) {
auto *var = scope.FindVar(var_name);
if (var->IsType<LoDTensor>()) {
VLOG(10) << "Checking " << var_name;
PADDLE_ENFORCE(!framework::HasNAN(var->Get<framework::LoDTensor>()),
"%s has NAN", var_name);
}
};
check_var_no_nan(cur_scope, inside_grad_name);
auto new_inside_name = cur_scope.Rename(inside_grad_name);
check_var_no_nan(cur_scope, new_inside_name);
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_names[param_id], new_inside_name}}},
{{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
sum_op->Run(cur_scope, dev_place);
check_var_no_nan(cur_scope, pg_names[param_id]);
cur_scope.Rename(new_inside_name, inside_grad_name);
}
}
VLOG(1) << "Complete WhileOpGrad";
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册