From f5c2d175ae105e8938e8343068eff31db5745c19 Mon Sep 17 00:00:00 2001
From: Yang Yu
Date: Thu, 28 Dec 2017 10:25:18 +0800
Subject: [PATCH] Refine

---
 paddle/framework/executor.cc                   |  9 +++++----
 paddle/framework/tensor_impl.h                 | 13 +++++++++++--
 paddle/framework/variable.h                    |  1 +
 paddle/operators/fill_constant_op.cc           |  1 +
 paddle/operators/shrink_rnn_memory_op.cc       |  5 +++--
 paddle/operators/sum_op.h                      |  4 +---
 paddle/operators/tensor_array_read_write_op.cc |  2 +-
 paddle/operators/while_op.cc                   | 13 +++++++++++++
 8 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc
index 9ee2ddb7c3e..fe9a42ace01 100644
--- a/paddle/framework/executor.cc
+++ b/paddle/framework/executor.cc
@@ -59,15 +59,16 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
 
 static void CheckTensorNANOrInf(const std::string& name,
                                 const framework::Tensor& tensor) {
-  if (tensor.type().hash_code() != typeid(float).hash_code() &&
-      tensor.type().hash_code() != typeid(double).hash_code()) {
+  if (tensor.memory_size() == 0) {
     return;
   }
-  if (tensor.memory_size() == 0) {
+  if (tensor.type().hash_code() != typeid(float).hash_code() &&
+      tensor.type().hash_code() != typeid(double).hash_code()) {
     return;
   }
   PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name);
-  PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name);
+  PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN, %p", name,
+                 &tensor);
 }
 
 void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 6c6f298edc1..0161ed8c475 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -134,8 +134,17 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
 #endif
     offset_ = 0;
   }
-  return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
-                                 offset_);
+  void* buf = reinterpret_cast<void*>(
+      reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
+  if (type.hash_code() == typeid(float).hash_code() ||
+      type.hash_code() == typeid(double).hash_code()) {
+    float* tmp = (float*)(buf);
+    for (int64_t i = 0; i < numel(); ++i) {
+      tmp[i] = NAN;
+    }
+  }
+
+  return buf;
 }
 
 inline void* Tensor::mutable_data(platform::Place place) {
diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h
index e5a94759f92..3720393601a 100644
--- a/paddle/framework/variable.h
+++ b/paddle/framework/variable.h
@@ -35,6 +35,7 @@ class Variable {
   template <typename T>
   T* GetMutable() {
     if (!IsType<T>()) {
+      VLOG(10) << "Resetting " << *this->name_;
       holder_.reset(new PlaceholderImpl<T>(new T()));
     }
     return static_cast<T*>(holder_->Ptr());
diff --git a/paddle/operators/fill_constant_op.cc b/paddle/operators/fill_constant_op.cc
index dcd43a30c86..196c380c73e 100644
--- a/paddle/operators/fill_constant_op.cc
+++ b/paddle/operators/fill_constant_op.cc
@@ -51,6 +51,7 @@ class FillConstantOp : public framework::OperatorBase {
     platform::DeviceContextPool &pool =
         platform::DeviceContextPool::Instance();
     auto &dev_ctx = *pool.Get(dev_place);
+    VLOG(10) << "FillConstant to " << &out;
     math::set_constant(dev_ctx, &out, value);
   }
 };
diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc
index e5ef0740b6f..9ef473e7264 100644
--- a/paddle/operators/shrink_rnn_memory_op.cc
+++ b/paddle/operators/shrink_rnn_memory_op.cc
@@ -116,9 +116,10 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
       auto height = dout_tensor.dims()[0];
       auto slice = dx_tensor.Slice(0, static_cast<int>(height));
       framework::CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx, &slice);
-      if (dx_tensor.dims()[0] < height) {
+      VLOG(10) << dx_tensor.dims()[0] << ", " << height;
+      if (dx_tensor.dims()[0] > height) {
         auto rest_tensor = dx_tensor.Slice(
-            static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
+            static_cast<int>(height), static_cast<int>(dx_tensor.dims()[0]));
         math::set_constant(dev_ctx, &rest_tensor, 0.0f);
       }
     }
diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h
index eaa36aa1aea..d1277d3edd8 100644
--- a/paddle/operators/sum_op.h
+++ b/paddle/operators/sum_op.h
@@ -38,11 +38,9 @@ class SumKernel : public framework::OpKernel<T> {
 
     if (out_var->IsType<framework::LoDTensor>()) {
       auto *out = context.Output<Tensor>("Out");
-      out->mutable_data<T>(context.GetPlace());
-
       auto result = EigenVector<T>::Flatten(*out);
-
       if (!in_place) {
+        out->mutable_data<T>(context.GetPlace());
         math::SetConstant<DeviceContext, T> constant_functor;
         constant_functor(context.template device_context<DeviceContext>(), out,
                          0.0);
diff --git a/paddle/operators/tensor_array_read_write_op.cc b/paddle/operators/tensor_array_read_write_op.cc
index 53e38ec7033..d5ff3e3fce2 100644
--- a/paddle/operators/tensor_array_read_write_op.cc
+++ b/paddle/operators/tensor_array_read_write_op.cc
@@ -130,9 +130,9 @@ class ReadFromArrayOp : public ArrayOp {
     auto &x_array = x->Get<framework::LoDTensorArray>();
     auto *out = scope.FindVar(Output("Out"));
     PADDLE_ENFORCE(out != nullptr, "Out must be set");
-    auto *out_tensor = out->GetMutable<framework::LoDTensor>();
     size_t offset = GetOffset(scope, place);
     if (offset < x_array.size()) {
+      auto *out_tensor = out->GetMutable<framework::LoDTensor>();
       platform::DeviceContextPool &pool =
           platform::DeviceContextPool::Instance();
       auto &dev_ctx = *pool.Get(place);
diff --git a/paddle/operators/while_op.cc b/paddle/operators/while_op.cc
index 728ef607946..322270c8296 100644
--- a/paddle/operators/while_op.cc
+++ b/paddle/operators/while_op.cc
@@ -194,14 +194,27 @@ class WhileGradOp : public framework::OperatorBase {
           }
         }
 
+        auto check_var_no_nan = [](const framework::Scope &scope,
+                                   const std::string &var_name) {
+          auto *var = scope.FindVar(var_name);
+          if (var->IsType<framework::LoDTensor>()) {
+            VLOG(10) << "Checking " << var_name;
+            PADDLE_ENFORCE(!framework::HasNAN(var->Get<framework::LoDTensor>()),
+                           "%s has NAN", var_name);
+          }
+        };
+        check_var_no_nan(cur_scope, inside_grad_name);
         auto new_inside_name = cur_scope.Rename(inside_grad_name);
+        check_var_no_nan(cur_scope, new_inside_name);
         auto sum_op = framework::OpRegistry::CreateOp(
             "sum", {{"X", {pg_names[param_id], new_inside_name}}},
             {{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
         sum_op->Run(cur_scope, dev_place);
+        check_var_no_nan(cur_scope, pg_names[param_id]);
         cur_scope.Rename(new_inside_name, inside_grad_name);
       }
     }
 
+    VLOG(1) << "Complete WhileOpGrad";
   }
 };
--
GitLab
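
Note on the tensor_impl.h hunk above: it implements a NaN-poisoning debugging technique. Every freshly allocated float/double tensor buffer is filled with NAN, so any operator that consumes a tensor before writing it will trip the CheckTensorNANOrInf guard in executor.cc (or the check_var_no_nan lambda in while_op.cc) and identify the offending variable by name. Below is a minimal standalone C++ sketch of the same idea; poisoned_alloc and has_nan are hypothetical names used for illustration only, not PaddlePaddle APIs.

#include <cmath>
#include <cstdio>
#include <vector>

// Allocate n floats and poison them with NaN, mirroring the fill loop the
// patch adds to Tensor::mutable_data (hypothetical helper, illustration only).
std::vector<float> poisoned_alloc(size_t n) {
  return std::vector<float>(n, NAN);
}

// Report whether the buffer still holds NaN, i.e. whether some element was
// never initialized before use (plays the role of framework::HasNAN here).
bool has_nan(const std::vector<float>& buf) {
  for (float v : buf) {
    if (std::isnan(v)) return true;
  }
  return false;
}

int main() {
  std::vector<float> t = poisoned_alloc(4);
  t[0] = 1.0f;  // deliberately leave t[1..3] unwritten
  // A guard like CheckTensorNANOrInf would fire at this point and name the
  // operator that forgot to initialize the rest of the buffer.
  std::printf("has_nan: %s\n", has_nan(t) ? "yes" : "no");
  return 0;
}

Note that the patch writes the poison only for float/double tensors (the only types the NaN checks inspect), and reports violations through PADDLE_ENFORCE so a failure aborts with the tensor's name and address rather than propagating garbage further through the program.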