Commit f5c2d175 authored by Yang Yu

Refine

Parent 6f5e64af
@@ -59,15 +59,16 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
 static void CheckTensorNANOrInf(const std::string& name,
                                 const framework::Tensor& tensor) {
-  if (tensor.type().hash_code() != typeid(float).hash_code() &&
-      tensor.type().hash_code() != typeid(double).hash_code()) {
+  if (tensor.memory_size() == 0) {
     return;
   }
-  if (tensor.memory_size() == 0) {
+  if (tensor.type().hash_code() != typeid(float).hash_code() &&
+      tensor.type().hash_code() != typeid(double).hash_code()) {
     return;
   }
   PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name);
-  PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name);
+  PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN, %p", name,
+                 &tensor);
 }

 void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
......
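Note on this hunk: the empty-tensor guard now runs before the dtype guard, presumably because tensor.type() requires an allocated holder while memory_size() is safe to call on an empty tensor, and the extra %p in the NAN message prints the tensor's address so it can be matched against the "FillConstant to" trace added below. For reference, a minimal host-side sketch of what such a scan does (the real framework::HasNAN/HasInf also handle GPU tensors and multiple dtypes):

    #include <cmath>
    #include <cstddef>

    // Scan a CPU float buffer for non-finite values; a simplified
    // stand-in for framework::HasNAN / framework::HasInf.
    bool HasNANOrInf(const float* data, size_t n) {
      for (size_t i = 0; i < n; ++i) {
        if (std::isnan(data[i]) || std::isinf(data[i])) return true;
      }
      return false;
    }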
@@ -134,8 +134,17 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
 #endif
     offset_ = 0;
   }
-  return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
-                                 offset_);
+  void* buf = reinterpret_cast<void*>(
+      reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
+  if (type.hash_code() == typeid(float).hash_code() ||
+      type.hash_code() == typeid(double).hash_code()) {
+    float* tmp = (float*)(buf);
+    for (int64_t i = 0; i < numel(); ++i) {
+      tmp[i] = NAN;
+    }
+  }
+  return buf;
 }

 inline void* Tensor::mutable_data(platform::Place place) {
......
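Note on this hunk: freshly requested float/double buffers are poisoned with NaN, so any value that is read before being written will trip the CheckTensorNANOrInf assertions above. Two caveats worth flagging: the fill writes through the raw pointer on the host, so it is only safe for CPU allocations, and casting the buffer to float* means a double tensor receives numel() float-sized NaNs; on a typical IEEE-754 little-endian target that fills only half of its bytes, and the resulting doubles are not NaN. A type-correct variant might look like this (hypothetical helper, not part of the commit):

    #include <cmath>
    #include <cstdint>

    // Debug-only poisoning of a freshly allocated CPU buffer, writing
    // NaNs of the element type rather than reinterpreting as float.
    void PoisonBuffer(void* buf, int64_t numel, bool is_double) {
      if (is_double) {
        double* p = static_cast<double*>(buf);
        for (int64_t i = 0; i < numel; ++i) p[i] = NAN;
      } else {
        float* p = static_cast<float*>(buf);
        for (int64_t i = 0; i < numel; ++i) p[i] = NAN;
      }
    }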
@@ -35,6 +35,7 @@ class Variable {
   template <typename T>
   T* GetMutable() {
     if (!IsType<T>()) {
+      VLOG(10) << "Resetting " << *this->name_;
       holder_.reset(new PlaceholderImpl<T>(new T()));
     }
     return static_cast<T*>(holder_->Ptr());
......
@@ -51,6 +51,7 @@ class FillConstantOp : public framework::OperatorBase {
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
+    VLOG(10) << "FillConstant to " << &out;
     math::set_constant(dev_ctx, &out, value);
   }
 };
......
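Note on these two hunks: the VLOG(10) lines trace when a Variable's holder is replaced and which tensor address FillConstant writes to; together with the %p printed by the NAN check above, they let a failing tensor be traced back to its producer. VLOG(n) output is compiled in but off by default; assuming the usual glog setup, it only appears once verbosity is raised:

    #include <glog/logging.h>

    int main(int argc, char** argv) {
      google::InitGoogleLogging(argv[0]);
      FLAGS_logtostderr = true;
      FLAGS_v = 10;  // or run with GLOG_v=10 in the environment
      VLOG(10) << "emitted at verbosity >= 10";
      VLOG(11) << "suppressed at verbosity 10";
      return 0;
    }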
@@ -116,9 +116,10 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
       auto height = dout_tensor.dims()[0];
       auto slice = dx_tensor.Slice(0, static_cast<int>(height));
       framework::CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx, &slice);
-      if (dx_tensor.dims()[0] < height) {
+      VLOG(10) << dx_tensor.dims()[0] << ", " << height;
+      if (dx_tensor.dims()[0] > height) {
         auto rest_tensor = dx_tensor.Slice(
-            static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
+            static_cast<int>(height), static_cast<int>(dx_tensor.dims()[0]));
         math::set_constant(dev_ctx, &rest_tensor, 0.0f);
       }
     }
......
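Note on this hunk: this is the actual bug fix the debugging uncovered. dx has at least as many rows as the shrunk dout, so the old condition dx_tensor.dims()[0] < height could never hold, and the old slice end used dout's row count instead of dx's; the tail rows of dx were therefore left with stale memory, which the NaN poisoning above now makes visible. A plain-array sketch of the intended behavior, assuming row-major float data:

    #include <algorithm>
    #include <cstdint>
    #include <cstring>

    // Copy the first `height` rows of dout into dx, then zero the rest.
    void ShrinkRNNMemoryGrad(const float* dout, float* dx,
                             int64_t dx_rows, int64_t height, int64_t width) {
      std::memcpy(dx, dout, sizeof(float) * height * width);
      if (dx_rows > height) {
        std::fill(dx + height * width, dx + dx_rows * width, 0.0f);
      }
    }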
@@ -38,11 +38,9 @@ class SumKernel : public framework::OpKernel<T> {
     if (out_var->IsType<framework::LoDTensor>()) {
       auto *out = context.Output<Tensor>("Out");
-      out->mutable_data<T>(context.GetPlace());
       auto result = EigenVector<T>::Flatten(*out);
       if (!in_place) {
+        out->mutable_data<T>(context.GetPlace());
         math::SetConstant<DeviceContext, T> constant_functor;
         constant_functor(context.template device_context<DeviceContext>(), out,
                          0.0);
......
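Note on this hunk: with mutable_data now NaN-poisoning the buffer on every call, calling it unconditionally would clobber Out exactly when the sum runs in place (Out aliasing the first input), so the call moves under !in_place. A self-contained sketch of the aliasing hazard, with hypothetical plain-pointer names:

    #include <cstddef>
    #include <vector>

    // Zero-initializing the output is only safe when it is not also
    // one of the inputs; in place, start accumulating from xs[1].
    void Sum(const std::vector<const float*>& xs, float* out, size_t n) {
      bool in_place = !xs.empty() && xs[0] == out;
      if (!in_place) {
        for (size_t i = 0; i < n; ++i) out[i] = 0.0f;
      }
      for (size_t k = in_place ? 1 : 0; k < xs.size(); ++k) {
        for (size_t i = 0; i < n; ++i) out[i] += xs[k][i];
      }
    }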
@@ -130,9 +130,9 @@ class ReadFromArrayOp : public ArrayOp {
     auto &x_array = x->Get<framework::LoDTensorArray>();
     auto *out = scope.FindVar(Output("Out"));
     PADDLE_ENFORCE(out != nullptr, "Out must be set");
-    auto *out_tensor = out->GetMutable<framework::LoDTensor>();
     size_t offset = GetOffset(scope, place);
     if (offset < x_array.size()) {
+      auto *out_tensor = out->GetMutable<framework::LoDTensor>();
       platform::DeviceContextPool &pool =
           platform::DeviceContextPool::Instance();
       auto &dev_ctx = *pool.Get(place);
......
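Note on this hunk: GetMutable is deferred until the offset is known to be in range, so an out-of-range read leaves the Out variable exactly as it was instead of potentially resetting its holder (which the variable.h hunk above now logs). A toy demonstration of the guard-ordering idea, using a hypothetical Variable-like holder:

    #include <cassert>

    // Toy stand-in for a Variable whose GetMutable has a side effect.
    struct Var {
      float value = 1.0f;
      bool touched = false;
      float* GetMutable() { touched = true; return &value; }
    };

    // Materialize the output only after the index is validated.
    void ReadFromArray(Var& out, int offset, int size) {
      if (offset < size) {
        *out.GetMutable() = 2.0f;  // stand-in for the tensor copy
      }
    }

    int main() {
      Var v;
      ReadFromArray(v, 5, 3);  // out of range: Out is left untouched
      assert(!v.touched && v.value == 1.0f);
      return 0;
    }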
@@ -194,14 +194,27 @@ class WhileGradOp : public framework::OperatorBase {
         }
       }
+      auto check_var_no_nan = [](const framework::Scope &scope,
+                                 const std::string &var_name) {
+        auto *var = scope.FindVar(var_name);
+        if (var->IsType<LoDTensor>()) {
+          VLOG(10) << "Checking " << var_name;
+          PADDLE_ENFORCE(!framework::HasNAN(var->Get<framework::LoDTensor>()),
+                         "%s has NAN", var_name);
+        }
+      };
+      check_var_no_nan(cur_scope, inside_grad_name);
       auto new_inside_name = cur_scope.Rename(inside_grad_name);
+      check_var_no_nan(cur_scope, new_inside_name);
       auto sum_op = framework::OpRegistry::CreateOp(
           "sum", {{"X", {pg_names[param_id], new_inside_name}}},
           {{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
       sum_op->Run(cur_scope, dev_place);
+      check_var_no_nan(cur_scope, pg_names[param_id]);
       cur_scope.Rename(new_inside_name, inside_grad_name);
     }
   }
+  VLOG(1) << "Complete WhileOpGrad";
 }
 };
......
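Note on this hunk: the check_var_no_nan lambda brackets the gradient-accumulation hand-off (check the per-iteration gradient, rename it to a temporary, sum it into the persistent parameter gradient with a dynamically created sum op, check the result), so a NaN can be attributed to a specific step. Taken together, the commit applies the classic poison-and-check debugging pattern; a self-contained demonstration of the idea:

    #include <cassert>
    #include <cmath>
    #include <cstdlib>

    // Poison fresh memory with NaN; any element read before being
    // written then shows up as NaN at the next checkpoint.
    int main() {
      const int n = 8;
      float* buf = static_cast<float*>(std::malloc(n * sizeof(float)));
      for (int i = 0; i < n; ++i) buf[i] = NAN;        // poison on allocation
      for (int i = 0; i < n - 1; ++i) buf[i] = 1.0f;   // one write is missed
      bool has_nan = false;
      for (int i = 0; i < n; ++i) has_nan = has_nan || std::isnan(buf[i]);
      assert(has_nan);  // the missed write is caught at the check
      std::free(buf);
      return 0;
    }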