Unverified commit effc5559, authored by Jiabin Yang, committed by GitHub

test=develop, lazy init Grad (#17653)

Parent 33a791dd
@@ -4,6 +4,5 @@ cc_library(tracer SRCS tracer.cc DEPS proto_desc device_context pybind profiler)
 cc_library(engine SRCS engine.cc)
 cc_library(imperative_profiler SRCS profiler.cc)
 cc_library(nccl_context SRCS nccl_context.cc DEPS device_context)
 cc_test(nccl_context_test SRCS nccl_context_test.cc DEPS nccl_context)
 endif()
@@ -97,6 +97,13 @@ void AddTo(Variable* src, Variable* dst, platform::Place place) {
   boost::apply_visitor(func, place);
 }
 
+void ZeroGrads(VarBase* vb, const platform::Place& place) {
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(place);
+  auto grad_t = vb->var_->GetMutable<framework::LoDTensor>();
+  operators::math::set_constant(*dev_ctx, grad_t, 0.0);
+}
+
 void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
   PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(),
                  "Can't find %s in backward grad map", target->Name());
@@ -110,9 +117,9 @@ void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
   for (auto& var_pair : current.second) {
     Variable* origin_grad = target->var_.get();
     Variable* grad_to_add = var_pair.second->var_.get();
-    VLOG(2) << "add origin_grad: " << target->Name();
-    VLOG(2) << "added grad: " << var_pair.second->Name()
-            << " trace id is: " << var_pair.first;
+    VLOG(10) << "add origin_grad: " << target->Name();
+    VLOG(10) << "added grad: " << var_pair.second->Name()
+             << " trace id is: " << var_pair.first;
     AddTo(grad_to_add, origin_grad, current.first);
     delete var_pair.second;
     var_pair.second = nullptr;
@@ -127,7 +134,7 @@ class Autograd {
     if (var->IsStopGradient()) {
       return;
     }
-    VLOG(3) << "start autograd";
+    VLOG(2) << "start autograd";
     BackwardSumMap bck_map;
     GradientRef grad_ref;
     std::deque<OpBase*> ready;
@@ -195,7 +202,7 @@ class Autograd {
       for (auto it : candidate->pre_ops_) {
         for (OpBase* pre_op : it.second) {
           if (!pre_op) continue;
-          VLOG(2) << "op dep " << candidate->Type() << " trace id "
+          VLOG(9) << "op dep " << candidate->Type() << " trace id "
                   << candidate->trace_id_ << " <---- " << it.first << " <---- "
                   << pre_op->Type() << " trace id " << pre_op->trace_id_;
           if (visited.find(pre_op) == visited.end()) {
@@ -267,9 +274,11 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
     for (const auto& it : grad_output_variable_map) {
       auto& outputs = tmp_grad_outputs[k][it.first];
       outputs.reserve(it.second.size());
-      for (size_t i = 0; i < it.second.size(); ++i) {
-        VarBase* origin_grad_var_base = it.second[i];
-
+      for (VarBase* origin_grad_var_base : it.second) {
+        if (!origin_grad_var_base->IsInitialize()) {
+          origin_grad_var_base->InitBuffer();
+          ZeroGrads(origin_grad_var_base, place_);
+        }
         // Allocate a new variable
         VarBase* tmp_grad_var_base = new VarBase(
             string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
@@ -304,11 +313,15 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
     for (const auto& it : grad_input_vars_[k]) {
       auto& grad_invars = grad_invars_map[it.first];
       grad_invars.reserve(it.second.size());
-      for (const VarBase* grad_inp : it.second) {
+      for (VarBase* grad_inp : it.second) {
         PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                 grad_op_desc->Type(), grad_inp->Name());
-
-        grad_invars.emplace_back(grad_inp->var_.get());
+        if (!grad_inp->IsInitialize()) {
+          grad_inp->InitBuffer();
+          ZeroGrads(grad_inp, place_);
+        }
+        const VarBase* const_grad_inp = grad_inp;
+        grad_invars.emplace_back(const_grad_inp->var_.get());
       }
     }
@@ -343,22 +356,23 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
       // track outputs used by sum
       if (bck_stratedy.sorted_sum_gradient_) {
 #ifndef PADDLE_WITH_CUDA
-        VLOG(2) << "origin_outputs is : " << origin_outputs[i]->Name() << " ";
-        VLOG(2) << origin_outputs[i]
-                       ->var_->GetMutable<framework::LoDTensor>()
-                       ->data<float>()[0];
-        VLOG(2) << "outputs is : " << outputs[i]->Name() << " ";
-        VLOG(2) << outputs[i]
-                       ->var_->GetMutable<framework::LoDTensor>()
-                       ->data<float>()[0];
+        VLOG(10) << "origin_outputs is : " << origin_outputs[i]->Name()
+                 << " ";
+        VLOG(10) << origin_outputs[i]
+                        ->var_->GetMutable<framework::LoDTensor>()
+                        ->data<float>()[0];
+        VLOG(10) << "outputs is : " << outputs[i]->Name() << " ";
+        VLOG(10) << outputs[i]
+                        ->var_->GetMutable<framework::LoDTensor>()
+                        ->data<float>()[0];
 #endif
         if (bck_map->find(origin_outputs[i]) != bck_map->end()) {
-          VLOG(2) << "add sub grad to " << origin_outputs[i]->Name();
+          VLOG(10) << "add sub grad to " << origin_outputs[i]->Name();
           bck_map->at(origin_outputs[i])
               .second.emplace_back(
                   std::pair<int, VarBase*>(this->trace_id_, outputs[i]));
         } else {
-          VLOG(2) << "insert new map for " << origin_outputs[i]->Name();
+          VLOG(10) << "insert new map for " << origin_outputs[i]->Name();
           std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>
               tmp(place_, {std::make_pair(this->trace_id_, outputs[i])});
           bck_map->insert(std::make_pair(origin_outputs[i], tmp));
@@ -370,19 +384,19 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
           PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1,
                          "Backward error when calculate grad reference");
           if (grad_ref->at(origin_outputs[i]) > 1) {
-            VLOG(2) << "remove ref for " << origin_outputs[i]->Name();
+            VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
             grad_ref->at(origin_outputs[i])--;
           } else {
-            VLOG(2) << "Add grad for: " << origin_outputs[i]->Name();
+            VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
             AddGradBySort(bck_map, origin_outputs[i]);
             grad_ref->at(origin_outputs[i])--;
           }
         } else {
           framework::Variable* grad = outputs[i]->var_.get();
           framework::Variable* orig_grad = origin_outputs[i]->var_.get();
-          VLOG(2) << "AddTo Called with orig_grad is: "
-                  << origin_outputs[i]->name_ << " Grad to be added is "
-                  << outputs[i]->name_;
+          VLOG(10) << "AddTo Called with orig_grad is: "
+                   << origin_outputs[i]->name_ << " Grad to be added is "
+                   << outputs[i]->name_;
           AddTo(grad, orig_grad, place_);
           delete outputs[i];
         }
@@ -413,6 +427,7 @@ void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
   if (!pre_op_) return;
   platform::RecordEvent record_event("Imperative Backward");
   VLOG(3) << "start backward";
+  grads_->InitBuffer();
   auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
   operators::math::set_constant(
       *(platform::DeviceContextPool::Instance().Get(
......
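The hunks above defer gradient-buffer allocation to the backward pass: before a gradient VarBase is read or accumulated into, ApplyGrad now checks IsInitialize(), allocates the buffer with InitBuffer(), and zero-fills it via ZeroGrads(), replacing the eager set_constant that used to run at trace time. A minimal, framework-free Python sketch of that allocate-on-first-use guard (the dict-based slot and helper name are illustrative, not Paddle's API):

```python
import numpy as np


def accumulate_grad(grad_slot, delta):
    """Allocate-on-first-use guard, mirroring the IsInitialize() /
    InitBuffer() / ZeroGrads() sequence ApplyGrad performs before AddTo()."""
    if grad_slot["buffer"] is None:                   # IsInitialize() == false
        grad_slot["buffer"] = np.zeros_like(delta)    # InitBuffer() + ZeroGrads()
    grad_slot["buffer"] += delta                      # AddTo()
    return grad_slot


slot = {"buffer": None}   # gradient created by the tracer, no memory touched yet
accumulate_grad(slot, np.ones((2, 3), dtype=np.float32))
accumulate_grad(slot, np.ones((2, 3), dtype=np.float32))
print(slot["buffer"])     # every entry is 2.0 after two accumulations
```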
@@ -121,7 +121,7 @@ class VarBase {
       : VarBase(name, var->Get<framework::LoDTensor>().type(),
                 var->Get<framework::LoDTensor>().dims(),
                 var->Get<framework::LoDTensor>().place(), nullptr, grad,
-                stop_gradient, false) {
+                stop_gradient, false, true) {
     var_ = std::move(var);
   }
@@ -137,18 +137,27 @@ class VarBase {
           const framework::DDim& shape, const platform::Place& place,
           bool stop_gradient, bool persistable)
       : VarBase(name, dtype, shape, place, nullptr, nullptr, stop_gradient,
-                persistable) {}
+                persistable, true) {}
+
+  // Grad used constructor
+  VarBase(const std::string& name, const framework::proto::VarType::Type dtype,
+          const std::vector<int64_t>& shape, const platform::Place& place,
+          bool stop_gradient, bool persistable, bool need_initialize)
+      : VarBase(name, dtype, framework::make_ddim(shape), place, nullptr,
+                nullptr, stop_gradient, persistable, need_initialize) {}
 
  private:
   // TODO(minqiyang): need support SelectedRows
   VarBase(const std::string& name, framework::proto::VarType::Type dtype,
           const framework::DDim& shape, const platform::Place& place,
           std::unique_ptr<framework::Variable> var, VarBase* grad,
-          bool stop_gradient, bool persistable)
+          bool stop_gradient, bool persistable, bool need_initialize)
       : name_(name),
         type_(framework::proto::VarType::LOD_TENSOR),
+        place_(place),
         var_(std::move(var)),
         grads_(grad),
+        dtype_(dtype),
         stop_gradient_(stop_gradient),
         persistable_(persistable),
         pre_op_(nullptr),
@@ -159,9 +168,17 @@ class VarBase {
     }
     auto tensor = var_->GetMutable<framework::LoDTensor>();
     tensor->Resize(shape);
-    tensor->mutable_data(place, dtype);
-    VLOG(10) << "create varbase: " << name_ << " type: " << dtype
-             << " place: " << place;
+    if (need_initialize) {
+      tensor->mutable_data(place, dtype);
+      is_initialized_ = true;
+      VLOG(2) << "initialized varbase: " << name_ << " type: " << dtype
+              << " place: " << place;
+    } else {
+      is_initialized_ = false;
+      VLOG(2) << "not initialized varbase: " << name_;
+    }
+    VLOG(2) << "create varbase: " << name_ << " type: " << dtype
+            << " place: " << place;
   }
 
  public:
@@ -173,10 +190,12 @@ class VarBase {
     pre_op_ = nullptr;
     pre_op_out_idx_ = -1;
+    VLOG(2) << "destruct varbase: " << name_;
   }
 
   inline void SetName(const std::string& name) { name_ = name; }
   inline std::string Name() const { return name_; }
+  inline bool IsInitialize() const { return is_initialized_; }
 
   inline std::vector<int64_t> Shape() const {
     if (var_->IsInitialized()) {
@@ -211,7 +230,7 @@ class VarBase {
   inline void SetPersistable(bool persistable) { persistable_ = persistable; }
   inline bool IsPersistable() const { return persistable_; }
+  inline platform::Place GetPlace() { return place_; }
   inline OpBase* PreOp() const { return pre_op_; }
   inline int PreOpOutIdx() const { return pre_op_out_idx_; }
@@ -225,6 +244,17 @@ class VarBase {
     }
   }
 
+  void InitBuffer() {
+    if (!is_initialized_) {
+      var_->GetMutable<framework::LoDTensor>()->mutable_data(place_, dtype_);
+      is_initialized_ = true;
+      VLOG(2) << "initialized varbase: " << name_ << " type: " << dtype_
+              << " place: " << place_;
+    } else {
+      VLOG(2) << "var: " << name_ << " has already been initialized ";
+    }
+  }
+
   void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name,
                   int pre_op_out_idx, bool pre_op_stop_gradient) {
     pre_op_ = pre_op;
@@ -263,9 +293,10 @@ class VarBase {
   VarBase* grads_;
 
  private:
+  framework::proto::VarType::Type dtype_;
   bool stop_gradient_;
   bool persistable_;
+  bool is_initialized_;
   OpBase* pre_op_;
   std::string pre_op_out_name_;
   int pre_op_out_idx_;
......
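In the header, the laziness is carried by the new need_initialize constructor flag and the is_initialized_ member: forward variables still get their tensor buffer in the constructor, while gradient VarBases built with need_initialize=false wait for an explicit InitBuffer() call. A toy Python analogue of that construction-time choice, assuming nothing about Paddle's real classes (the zero fill is folded into init_buffer here, whereas the C++ code does it in a separate ZeroGrads step):

```python
import numpy as np


class LazyVar:
    """Toy analogue of the VarBase change: allocate in the constructor only
    when need_initialize is True, otherwise defer to init_buffer(), which is
    safe to call more than once."""

    def __init__(self, name, shape, dtype=np.float32, need_initialize=True):
        self.name = name
        self.shape = shape
        self.dtype = dtype
        self.is_initialized = False
        self.data = None
        if need_initialize:
            self.init_buffer()

    def init_buffer(self):
        if not self.is_initialized:
            self.data = np.zeros(self.shape, dtype=self.dtype)
            self.is_initialized = True
        # else: already initialized, nothing to do


x = LazyVar("x", (100, 100))                                    # forward var: eager buffer
x_grad = LazyVar("x@GRAD", (100, 100), need_initialize=False)   # grad var: no buffer yet
assert not x_grad.is_initialized
x_grad.init_buffer()   # first backward-pass use allocates the buffer
assert x_grad.is_initialized
```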
@@ -46,7 +46,7 @@ void CreateGradOp(const framework::OpDesc& op_desc,
   }
 }
 
-void InitGrad(VarBase* var, platform::DeviceContext* dev_ctx) {
+void CreateNoBuffuerGrad(VarBase* var, platform::DeviceContext* dev_ctx) {
   PADDLE_ENFORCE_NOT_NULL(var, "Could not get valid var base");
   PADDLE_ENFORCE_NOT_NULL(dev_ctx,
                           "Could not get valid device from forward op");
@@ -55,9 +55,7 @@ void InitGrad(VarBase* var, platform::DeviceContext* dev_ctx) {
     auto& var_t = var->var_->Get<framework::LoDTensor>();
     var->grads_ = new VarBase(var->GradName(), framework::proto::VarType::FP32,
                               framework::vectorize(var_t.dims()),
-                              dev_ctx->GetPlace(), true, false);
-    auto grad_t = var->grads_->var_->GetMutable<framework::LoDTensor>();
-    operators::math::set_constant(*dev_ctx, grad_t, 0.0);
+                              dev_ctx->GetPlace(), true, false, false);
   }
 }
@@ -261,7 +259,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
             grad_in_vars.emplace_back(fwd_var_it->second);
           } else {
             VarBase* var = current_vars_map[var_it->second];
-            InitGrad(var, prepared_op.GetDeviceContext());
+            CreateNoBuffuerGrad(var, prepared_op.GetDeviceContext());
             // Douts.
             grad_in_vars.emplace_back(var->grads_);
           }
@@ -279,7 +277,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
                            "operator %s's stop gradient be True",
                            op->Type());
             VarBase* var = current_vars_map[var_it->second];
-            InitGrad(var, prepared_op.GetDeviceContext());
+            CreateNoBuffuerGrad(var, prepared_op.GetDeviceContext());
             grad_out_vars.push_back(var->grads_);
             VLOG(3) << "grads output var name: " << var->name_;
           }
@@ -289,6 +287,5 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   return vars_saved_for_backward;
 }
 
 }  // namespace imperative
 }  // namespace paddle
@@ -147,7 +147,8 @@ class Layer(core.Layer):
     def clear_gradients(self):
         for p in self.parameters():
-            p.clear_gradient()
+            if p.trainable:
+                p.clear_gradient()
 
     def build_once(self, *args):
         pass
......
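With gradients created lazily, a parameter that is never trained may never get a gradient buffer at all, so the updated Layer.clear_gradients only touches trainable parameters. A small self-contained sketch of the guarded loop, using toy stand-in classes rather than the fluid API:

```python
class FakeParam:
    """Stand-in for a dygraph parameter whose gradient buffer may never exist."""

    def __init__(self, name, trainable=True):
        self.name = name
        self.trainable = trainable
        self.grad = None            # lazily created; stays None if never trained

    def clear_gradient(self):
        if self.grad is not None:
            self.grad[...] = 0.0


def clear_gradients(params):
    # Same guard as the updated Layer.clear_gradients: leave non-trainable
    # parameters (and their never-allocated gradients) alone.
    for p in params:
        if p.trainable:
            p.clear_gradient()


clear_gradients([FakeParam("w"), FakeParam("bias", trainable=False)])
```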
@@ -101,11 +101,11 @@ class TestDygraphGNN(unittest.TestCase):
             ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
             exe.run(startup)
             static_loss = exe.run(feed={
-                'features': np.zeros(
+                'features': np.ones(
                     [1, 100, 50], dtype=np.float32),
-                'adj': np.zeros(
+                'adj': np.ones(
                     [1, 100, 100], dtype=np.float32),
-                'labels': np.zeros(
+                'labels': np.ones(
                     [100, 1], dtype=np.int64)
             },
                                    fetch_list=[loss])[0]
@@ -117,10 +117,10 @@ class TestDygraphGNN(unittest.TestCase):
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
 
-            features = np.zeros([1, 100, 50], dtype=np.float32)
+            features = np.ones([1, 100, 50], dtype=np.float32)
             # Use selected rows when it's supported.
-            adj = np.zeros([1, 100, 100], dtype=np.float32)
-            labels = np.zeros([100, 1], dtype=np.int64)
+            adj = np.ones([1, 100, 100], dtype=np.float32)
+            labels = np.ones([100, 1], dtype=np.int64)
 
             model = GCN('test_gcn', 50)
             logits = model(to_variable(features), to_variable(adj))
@@ -130,17 +130,20 @@ class TestDygraphGNN(unittest.TestCase):
             loss = fluid.layers.softmax_with_cross_entropy(logits,
                                                            to_variable(labels))
             loss = fluid.layers.reduce_sum(loss)
+            loss.backward()
             adam = AdamOptimizer(learning_rate=1e-3)
             adam.minimize(loss)
+            model.clear_gradients()
 
         with fluid.dygraph.guard():
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
 
-            features2 = np.zeros([1, 100, 50], dtype=np.float32)
+            features2 = np.ones([1, 100, 50], dtype=np.float32)
             # Use selected rows when it's supported.
-            adj2 = np.zeros([1, 100, 100], dtype=np.float32)
-            labels2 = np.zeros([100, 1], dtype=np.int64)
+            adj2 = np.ones([1, 100, 100], dtype=np.float32)
+            labels2 = np.ones([100, 1], dtype=np.int64)
 
             model2 = GCN('test_gcn', 50)
             logits2 = model2(to_variable(features2), to_variable(adj2))
@@ -150,8 +153,10 @@ class TestDygraphGNN(unittest.TestCase):
             loss2 = fluid.layers.softmax_with_cross_entropy(
                 logits2, to_variable(labels2))
             loss2 = fluid.layers.reduce_sum(loss2)
+            loss2.backward()
             adam2 = AdamOptimizer(learning_rate=1e-3)
             adam2.minimize(loss2)
+            model2.clear_gradients()
 
             self.assertEqual(static_loss, loss.numpy())
             self.assertTrue(np.allclose(static_weight, model.gc.weight.numpy()))
......