Unverified commit 7bb010e7, authored by Yinggang Wang, committed by GitHub

Support assign copy interface (#6228)

* feat(Tensor): support assign copy interface

* feat(Module): use assign copy

* test(Tensor): add test for tensor assign_copy

* fix(Optim): create buffer data in step()

* rename assign_copy to set_data

* fix(Parameter): fix Parameter set_data bug

* refine log to log_once in data()

* test(ConsistentTensor): add set_data test

* fix(Module): fix Module.to_consistent bug

* auto format by CI
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Parent e1a16561
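The user-facing effect of this change: assigning to tensor.data rebinds the tensor in place, so the Python object keeps its identity while taking over the other tensor's storage and metadata. A minimal sketch based on the new test_tensor_set_data added below (shapes and the CUDA device are only illustrative):

import oneflow as flow

x = flow.ones(2, 3, requires_grad=False)
y = flow.ones(4, 5, requires_grad=True).to("cuda")
old_id = id(x)
x.data = y                        # dispatches to Tensor::set_data
assert id(x) == old_id            # same Python object
assert x.shape == (4, 5)          # but x now shares y's impl (shape, device, autograd state)
assert x.device == flow.device("cuda")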
......@@ -151,6 +151,9 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
throw std::runtime_error("You can only change gradient of leaf tensors.");
}
})
.def_property(
"data", [](Tensor& t) { return t.data().GetPtrOrThrow(); },
[](Tensor& t, const std::shared_ptr<Tensor>& other) { t.set_data(other).GetOrThrow(); })
.def("storage_offset", [](const Tensor& t) { return t.storage_offset().GetOrThrow(); })
.def("stride",
[](const Tensor& t) {
......@@ -185,7 +188,6 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
// local tensor only
.def_property_readonly("_tensor_buffer_shapes_and_dtypes", &GetTensorBufferShapesAndDTypes)
.def_property_readonly("device", &TensorGetDevice)
.def_property_readonly("data", &Tensor::data)
.def("consistent_id",
[](const one::Tensor& tensor) -> int64_t {
return static_cast<uint64_t>(tensor.transport_token().GetOrThrow());
......
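From Python, the new read/write data property pairs Tensor::data() (getter) with Tensor::set_data() (setter). For a Parameter the getter returns the wrapped tensor; for a plain local or consistent tensor it returns the same underlying tensor and logs a one-time warning (see OF_LOG_ONCE below). A hedged sketch:

import oneflow as flow

x = flow.ones(2, 3)
d = x.data                        # same underlying tensor; warns once per process
p = flow.nn.Parameter(flow.ones(2, 3))
w = p.data                        # the tensor wrapped by the Parameter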
......@@ -381,6 +381,15 @@ std::string GetFormatedSerializedError(const std::shared_ptr<cfg::ErrorProto>& e
": "
#define RETURN_ERROR_WITH_BUG_PROMPT() OF_RUNTIME_ERROR() << kOfBugIssueUploadPrompt
#define OF_LOG_ONCE(x) \
{ \
static bool warned = false; \
if (!warned) { \
warned = true; \
x; \
} \
}
#define OF_COMPLIE_OPTION_ERROR() \
return Error::CompileOptionWrongError().AddStackFrame(__FILE__, __LINE__, __FUNCTION__) \
<< " Compile option wrong: "
......
......@@ -56,11 +56,6 @@ Maybe<MirroredTensor> StaticZerosTensor::AsMirroredTensor() {
bool MirroredTensor::is_cuda() const { return CHECK_JUST(device())->type() == "cuda"; }
std::shared_ptr<Tensor> MirroredTensor::data() const {
std::shared_ptr<MirroredTensor> t = std::make_shared<MirroredTensor>(impl_);
return t;
}
Maybe<Tensor> MirroredTensor::detach() const {
std::shared_ptr<Tensor> tensor = std::make_shared<MirroredTensor>(JUST(impl_->detach()));
return tensor;
......@@ -105,11 +100,6 @@ bool ConsistentTensor::is_cuda() const {
return CHECK_JUST(parallel_desc())->device_type() == DeviceType::kGPU;
}
std::shared_ptr<Tensor> ConsistentTensor::data() const {
std::shared_ptr<ConsistentTensor> t = std::make_shared<ConsistentTensor>(impl_);
return t;
}
Maybe<Tensor> ConsistentTensor::detach() const {
std::shared_ptr<Tensor> t = std::make_shared<ConsistentTensor>(impl_);
return t;
......
......@@ -16,6 +16,7 @@ limitations under the License.
#ifndef ONEFLOW_CORE_FRAMEWORK_TENSOR_H_
#define ONEFLOW_CORE_FRAMEWORK_TENSOR_H_
#include <memory>
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/data_type.cfg.h"
#include "oneflow/core/common/shape_view.h"
......@@ -62,6 +63,7 @@ class Tensor {
virtual bool is_lazy() const = 0;
virtual bool is_eager() const { return !is_lazy(); }
virtual const TensorMeta& tensor_meta() const = 0;
virtual Maybe<Tensor> data() = 0;
virtual Maybe<Symbol<ConsistentTensorMeta>> consistent_tensor_meta() const { OF_UNIMPLEMENTED(); }
// Getters valid only for EagerMirroredTensor
......@@ -89,7 +91,6 @@ class Tensor {
virtual Maybe<TensorArg> current_grad() const = 0;
virtual Maybe<Tensor> detach() const = 0;
virtual Maybe<Tensor> clone() const = 0;
virtual std::shared_ptr<Tensor> data() const = 0;
// Setters for autograd
virtual void set_requires_grad(bool requires_grad) = 0;
......@@ -104,6 +105,7 @@ class Tensor {
virtual void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) = 0;
virtual user_op::TensorDesc* mut_tensor_meta() = 0;
virtual Maybe<void> set_data(const std::shared_ptr<Tensor>& other) = 0;
virtual Maybe<MirroredTensor> AsMirroredTensor() = 0;
virtual Maybe<ConsistentTensor> AsConsistentTensor() = 0;
......@@ -119,110 +121,114 @@ class StaticZerosTensor final : public Tensor {
return std::shared_ptr<StaticZerosTensor>(new StaticZerosTensor(shape, dtype, device));
}
// Getters
const std::shared_ptr<const Shape>& shape() const { return shape_; }
Symbol<DType> dtype() const { return CHECK_JUST(DType::Get(dtype_)); }
Maybe<TransportToken> transport_token() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Symbol<cfg::NdSbp>> nd_sbp() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Symbol<ParallelDesc>> parallel_desc() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Symbol<Device>> device() const { return device_; }
Maybe<Symbol<Device>*> mut_device() { RETURN_ERROR_WITH_BUG_PROMPT(); }
bool is_cuda() const {
const std::shared_ptr<const Shape>& shape() const override { return shape_; }
Symbol<DType> dtype() const override { return CHECK_JUST(DType::Get(dtype_)); }
Maybe<TransportToken> transport_token() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Symbol<cfg::NdSbp>> nd_sbp() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Symbol<ParallelDesc>> parallel_desc() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Symbol<Device>> device() const override { return device_; }
Maybe<Symbol<Device>*> mut_device() override { RETURN_ERROR_WITH_BUG_PROMPT(); }
bool is_cuda() const override {
PRINT_BUG_PROMPT_AND_ABORT();
return false;
}
bool is_consistent() const { return false; }
bool is_local() const { return !is_consistent(); }
bool is_lazy() const {
bool is_consistent() const override { return false; }
bool is_local() const override { return !is_consistent(); }
bool is_lazy() const override {
PRINT_BUG_PROMPT_AND_ABORT();
return false;
}
bool is_eager() const { return !is_lazy(); }
const TensorMeta& tensor_meta() const {
bool is_eager() const override { return !is_lazy(); }
const TensorMeta& tensor_meta() const override {
PRINT_BUG_PROMPT_AND_ABORT();
return *(TensorMeta*)nullptr;
}
Maybe<Symbol<ConsistentTensorMeta>> consistent_tensor_meta() const {
Maybe<Tensor> data() override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Symbol<ConsistentTensorMeta>> consistent_tensor_meta() const override {
RETURN_ERROR_WITH_BUG_PROMPT();
}
// Getters valid only for EagerMirroredTensor
Maybe<EagerMirroredTensorImpl*> mut_eager_mirrored_tensor_impl() {
Maybe<EagerMirroredTensorImpl*> mut_eager_mirrored_tensor_impl() override {
RETURN_ERROR_WITH_BUG_PROMPT();
}
Maybe<vm::EagerBlobObject> eager_blob_object() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<LocalDepObject*> compute_local_dep_object() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<bool> has_eager_blob_object() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<TensorStorage> tensor_storage() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<const Stride> stride() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<int64_t> storage_offset() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<vm::EagerBlobObject> eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<LocalDepObject*> compute_local_dep_object() const override {
RETURN_ERROR_WITH_BUG_PROMPT();
}
Maybe<bool> has_eager_blob_object() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<TensorStorage> tensor_storage() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<const Stride> stride() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<int64_t> storage_offset() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
// Getters/Setters valid only for EagerConsistentTensor
Maybe<const Optional<Symbol<cfg::NdSbp>>&> consumer_nd_sbp_constraint() const {
Maybe<const Optional<Symbol<cfg::NdSbp>>&> consumer_nd_sbp_constraint() const override {
RETURN_ERROR_WITH_BUG_PROMPT();
}
Maybe<MirroredTensor> cur_rank_phy_tensor() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<void> set_consumer_nd_sbp_constraint(Symbol<cfg::NdSbp> val) {
Maybe<MirroredTensor> cur_rank_phy_tensor() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<void> set_consumer_nd_sbp_constraint(Symbol<cfg::NdSbp> val) override {
RETURN_ERROR_WITH_BUG_PROMPT();
}
// Getters for autograd
bool requires_grad() const {
bool requires_grad() const override {
PRINT_BUG_PROMPT_AND_ABORT();
return false;
}
bool is_leaf() const {
bool is_leaf() const override {
PRINT_BUG_PROMPT_AND_ABORT();
return false;
}
bool retain_grad() const {
bool retain_grad() const override {
PRINT_BUG_PROMPT_AND_ABORT();
return false;
}
std::shared_ptr<const FunctionNode> grad_fn_node() const {
PRINT_BUG_PROMPT_AND_ABORT();
return nullptr;
}
Maybe<Tensor> acc_grad() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<TensorArg> current_grad() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Tensor> detach() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Tensor> clone() const { RETURN_ERROR_WITH_BUG_PROMPT(); }
std::shared_ptr<Tensor> data() const {
std::shared_ptr<const FunctionNode> grad_fn_node() const override {
PRINT_BUG_PROMPT_AND_ABORT();
return nullptr;
}
Maybe<Tensor> acc_grad() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<TensorArg> current_grad() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Tensor> detach() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Tensor> clone() const override { RETURN_ERROR_WITH_BUG_PROMPT(); }
// Setters for autograd
void set_requires_grad(bool requires_grad) { PRINT_BUG_PROMPT_AND_ABORT(); }
Maybe<void> set_retain_grad(bool retain_grad) { RETURN_ERROR_WITH_BUG_PROMPT(); }
void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) {
void set_requires_grad(bool requires_grad) override { PRINT_BUG_PROMPT_AND_ABORT(); }
Maybe<void> set_retain_grad(bool retain_grad) override { RETURN_ERROR_WITH_BUG_PROMPT(); }
void set_grad_fn_node(const std::shared_ptr<FunctionNode>& grad_fn_node) override {
PRINT_BUG_PROMPT_AND_ABORT();
}
const std::shared_ptr<FunctionNode>& mut_grad_fn_node() {
const std::shared_ptr<FunctionNode>& mut_grad_fn_node() override {
PRINT_BUG_PROMPT_AND_ABORT();
return *(std::shared_ptr<FunctionNode>*)nullptr;
}
Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad) { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<Tensor> mut_acc_grad() { RETURN_ERROR_WITH_BUG_PROMPT(); }
void set_is_leaf(bool is_leaf) { PRINT_BUG_PROMPT_AND_ABORT(); }
std::shared_ptr<AutogradMeta> mut_autograd_meta() {
Maybe<void> set_acc_grad(const std::shared_ptr<Tensor>& grad) override {
RETURN_ERROR_WITH_BUG_PROMPT();
}
Maybe<Tensor> mut_acc_grad() override { RETURN_ERROR_WITH_BUG_PROMPT(); }
void set_is_leaf(bool is_leaf) override { PRINT_BUG_PROMPT_AND_ABORT(); }
std::shared_ptr<AutogradMeta> mut_autograd_meta() override {
PRINT_BUG_PROMPT_AND_ABORT();
return nullptr;
}
bool has_autograd_meta() const {
bool has_autograd_meta() const override {
PRINT_BUG_PROMPT_AND_ABORT();
return false;
}
void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) {
void set_autograd_meta(const std::shared_ptr<AutogradMeta>& autograd_meta) override {
PRINT_BUG_PROMPT_AND_ABORT();
}
user_op::TensorDesc* mut_tensor_meta() {
user_op::TensorDesc* mut_tensor_meta() override {
PRINT_BUG_PROMPT_AND_ABORT();
return nullptr;
}
Maybe<void> set_data(const std::shared_ptr<Tensor>& other) override {
RETURN_ERROR_WITH_BUG_PROMPT();
}
Maybe<MirroredTensor> AsMirroredTensor();
Maybe<ConsistentTensor> AsConsistentTensor() { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<MirroredTensor> AsMirroredTensor() override;
Maybe<ConsistentTensor> AsConsistentTensor() override { RETURN_ERROR_WITH_BUG_PROMPT(); }
private:
StaticZerosTensor(const std::shared_ptr<const Shape>& shape, DataType dtype,
......@@ -281,6 +287,7 @@ class Parameter final : public TensorIf<Parameter> {
Maybe<Symbol<ConsistentTensorMeta>> consistent_tensor_meta() const override {
return tensor_->consistent_tensor_meta();
}
Maybe<Tensor> data() override { return tensor_; }
Maybe<EagerMirroredTensorImpl*> mut_eager_mirrored_tensor_impl() override {
return tensor_->mut_eager_mirrored_tensor_impl();
......@@ -314,7 +321,6 @@ class Parameter final : public TensorIf<Parameter> {
Maybe<TensorArg> current_grad() const override { return tensor_->current_grad(); }
Maybe<Tensor> detach() const override { return tensor_->detach(); }
Maybe<Tensor> clone() const override { return tensor_->clone(); }
std::shared_ptr<Tensor> data() const override { return tensor_->data(); }
void set_requires_grad(bool requires_grad) override {
return tensor_->set_requires_grad(requires_grad);
......@@ -336,12 +342,22 @@ class Parameter final : public TensorIf<Parameter> {
}
user_op::TensorDesc* mut_tensor_meta() override { return tensor_->mut_tensor_meta(); }
Maybe<void> set_data(const std::shared_ptr<Tensor>& other) override {
std::shared_ptr<Tensor> tensor = other;
while (auto parameter = std::dynamic_pointer_cast<Parameter>(tensor)) {
tensor = parameter->tensor_;
}
CHECK_OR_RETURN(is_local() == tensor->is_local() && is_eager() == tensor->is_eager())
<< "You can't assign copy between tensors with different type";
this->tensor_ = std::move(tensor);
return Maybe<void>::Ok();
}
Maybe<MirroredTensor> AsMirroredTensor() override {
if (const auto& mirrored_tensor = std::dynamic_pointer_cast<MirroredTensor>(tensor_)) {
return mirrored_tensor;
}
OF_RUNTIME_ERROR() << "Parameter Tensor has no AsMirroredTensor property";
RETURN_ERROR_WITH_BUG_PROMPT();
}
Maybe<ConsistentTensor> AsConsistentTensor() override {
......@@ -389,8 +405,11 @@ class MirroredTensor final : public TensorIf<MirroredTensor>,
bool is_lazy() const override { return impl_->is_lazy(); }
bool is_consistent() const override { return false; }
bool is_cuda() const override;
std::shared_ptr<Tensor> data() const override;
const TensorMeta& tensor_meta() const override { return *impl_->tensor_meta(); }
Maybe<Tensor> data() override {
OF_LOG_ONCE(LOG(WARNING) << "You shouldn't call `.data` for a LocalTensor.");
return std::static_pointer_cast<Tensor>(shared_from_this());
}
// Getters valid only for EagerMirroredTensor
Maybe<vm::EagerBlobObject> eager_blob_object() const override {
......@@ -439,6 +458,13 @@ class MirroredTensor final : public TensorIf<MirroredTensor>,
return impl_->mut_eager_mirrored_tensor_impl();
}
user_op::TensorDesc* mut_tensor_meta() override { return impl_->mut_tensor_meta(); }
Maybe<void> set_data(const std::shared_ptr<Tensor>& other) override {
const auto& mirrored_tensor = std::dynamic_pointer_cast<MirroredTensor>(other);
CHECK_NOTNULL_OR_RETURN(mirrored_tensor);
impl_ = mirrored_tensor->impl_;
grad_fn_node_ = mirrored_tensor->grad_fn_node_;
return Maybe<void>::Ok();
}
Maybe<MirroredTensor> AsMirroredTensor() override { return shared_from_this(); }
Maybe<ConsistentTensor> AsConsistentTensor() override { RETURN_ERROR_WITH_BUG_PROMPT(); }
......@@ -477,7 +503,10 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor>,
return impl_->cur_rank_phy_tensor();
}
bool is_cuda() const override;
std::shared_ptr<Tensor> data() const override;
Maybe<Tensor> data() override {
OF_LOG_ONCE(LOG(WARNING) << "You shouldn't call `.data` for a ConsistentTensor.");
return std::static_pointer_cast<Tensor>(shared_from_this());
}
// Getters valid only for EagerMirroredTensor
Maybe<vm::EagerBlobObject> eager_blob_object() const override {
......@@ -535,6 +564,13 @@ class ConsistentTensor final : public TensorIf<ConsistentTensor>,
}
user_op::TensorDesc* mut_tensor_meta() override { return impl_->mut_tensor_meta(); }
Maybe<void> set_data(const std::shared_ptr<Tensor>& other) override {
const auto& consistent_tensor = std::dynamic_pointer_cast<ConsistentTensor>(other);
CHECK_NOTNULL_OR_RETURN(consistent_tensor);
impl_ = consistent_tensor->impl_;
grad_fn_node_ = consistent_tensor->grad_fn_node_;
return Maybe<void>::Ok();
}
Maybe<MirroredTensor> AsMirroredTensor() override { RETURN_ERROR_WITH_BUG_PROMPT(); }
Maybe<ConsistentTensor> AsConsistentTensor() override { return shared_from_this(); }
......
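Parameter::set_data also unwraps nested Parameters and rejects assignment across the local/consistent (or eager/lazy) boundary. A hedged sketch of that constraint, assuming the failed CHECK surfaces as a Python exception:

import oneflow as flow

p = flow.nn.Parameter(flow.ones(2, 3))                  # local parameter
c = flow.ones(2, 3, placement=flow.placement("cpu", {0: 0}),
              sbp=flow.sbp.broadcast)                   # consistent tensor
try:
    p.data = c   # rejected: "You can't assign copy between tensors with different type"
except Exception as e:
    print(e)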
......@@ -466,20 +466,32 @@ class Module(object):
def _apply(self, fn):
for module in self.children():
module._apply(fn)
def can_use_assign_copy(tensor, tensor_applied):
return tensor.is_local == tensor_applied.is_local
for (key, param) in self._parameters.items():
if param is not None:
assert isinstance(param, Parameter)
assert param.is_leaf
if param is None:
continue
assert isinstance(param, Parameter)
assert param.is_leaf
with flow.no_grad():
param_applied = fn(param)
param_applied.requires_grad = param.requires_grad
if param.grad is not None:
assert param.grad.is_leaf
with flow.no_grad():
param_applied = fn(param)
grad_applied = fn(param.grad)
grad_applied.requires_grad = param.grad.requires_grad
param_applied.grad = grad_applied
if can_use_assign_copy(param_applied, param):
self._parameters[key].data = param_applied
else:
self._parameters[key] = Parameter(param_applied, param.requires_grad)
if param.grad is not None:
assert param.grad.is_leaf
with flow.no_grad():
grad_applied = fn(param.grad)
self._parameters[key].grad = grad_applied.requires_grad_(
param.grad.requires_grad
)
for (key, buf) in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
......
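With can_use_assign_copy, Module._apply (and therefore Module.to, Module.to_consistent, ...) rebinds the existing Parameter objects through .data whenever the local/consistent kind is unchanged, instead of replacing them with new Parameters. One practical consequence, and the reason the quickstart test below now builds the optimizer before moving the network, is sketched here (module and hyperparameters are illustrative):

import oneflow as flow
import oneflow.nn as nn

net = nn.Linear(3, 4)
optimizer = flow.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
net.to("cuda")   # parameters are updated in place via assign copy, so the
                 # optimizer keeps referencing the same Parameter objects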
......@@ -142,9 +142,6 @@ class Adam(Optimizer):
for param in param_group.parameters:
assert param.is_leaf, "parameters must be leaf tensor"
self._state[param] = dict()
self._state[param]["exp_avg"] = flow.zeros_like(param)
self._state[param]["exp_avg_sq"] = flow.zeros_like(param)
self._state[param]["max_exp_avg_sq"] = flow.zeros_like(param)
self._op = (
flow.builtin_op("adam_update")
......@@ -193,6 +190,12 @@ class Adam(Optimizer):
for param in param_group.parameters:
if param.grad is None:
continue
if "exp_avg" not in self._state[param]:
self._state[param]["exp_avg"] = flow.zeros_like(param)
if "exp_avg_sq" not in self._state[param]:
self._state[param]["exp_avg_sq"] = flow.zeros_like(param)
if "max_exp_avg_sq" not in self._state[param]:
self._state[param]["max_exp_avg_sq"] = flow.zeros_like(param)
m_tensor = self._state[param]["exp_avg"]
v_tensor = self._state[param]["exp_avg_sq"]
max_v_tensor = self._state[param]["max_exp_avg_sq"]
......
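Since the Adam state buffers (exp_avg, exp_avg_sq, max_exp_avg_sq, and the analogous buffers in AdamW, RMSprop and SGD below) are now created lazily in step(), they are allocated from the parameter as it exists at the first update, after any module.to(...) calls. A hedged end-to-end sketch (model, data and learning rate are illustrative):

import oneflow as flow
import oneflow.nn as nn

net = nn.Linear(3, 4)
optimizer = flow.optim.Adam(net.parameters(), lr=0.01)
net.to("cuda")                    # safe: no state buffers have been created yet
x = flow.ones(8, 3).to("cuda")
loss = net(x).sum()
loss.backward()
optimizer.step()                  # exp_avg / exp_avg_sq created here, matching the moved params
optimizer.zero_grad()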
......@@ -144,9 +144,6 @@ class AdamW(Optimizer):
for param in param_group.parameters:
assert param.is_leaf, "parameters must be leaf tensor"
self._state[param] = dict()
self._state[param]["exp_avg"] = flow.zeros_like(param)
self._state[param]["exp_avg_sq"] = flow.zeros_like(param)
self._state[param]["max_exp_avg_sq"] = flow.zeros_like(param)
self._op = (
flow.builtin_op("adam_update")
......@@ -196,6 +193,12 @@ class AdamW(Optimizer):
if param.grad is None:
continue
if "exp_avg" not in self._state[param]:
self._state[param]["exp_avg"] = flow.zeros_like(param)
if "exp_avg_sq" not in self._state[param]:
self._state[param]["exp_avg_sq"] = flow.zeros_like(param)
if "max_exp_avg_sq" not in self._state[param]:
self._state[param]["max_exp_avg_sq"] = flow.zeros_like(param)
m_tensor = self._state[param]["exp_avg"]
v_tensor = self._state[param]["exp_avg_sq"]
max_v_tensor = self._state[param]["max_exp_avg_sq"]
......
......@@ -232,25 +232,14 @@ class Optimizer(object):
3. Optimizers have a different behavior if the gradient is 0 or None
(in one case it does the step with a gradient of 0 and in the other
it skips the step altogether).
Returns:
None
"""
all_grad_is_none = True
for param_group in self.param_groups:
for param in param_group.parameters:
if param.grad is not None:
all_grad_is_none = False
if set_to_none:
param.grad = None
else:
param.grad.zeros_()
if all_grad_is_none:
warnings.warn(
"\nParameters in optimizer do not have gradient.\nPlease check `loss.backward()` is called"
"or not,\nor try to declare optimizer after calling `module.to()`"
)
def _parse_input_parameters(self, parameters):
"""
......
......@@ -153,9 +153,7 @@ class RMSprop(Optimizer):
for param in param_group.parameters:
assert param.is_leaf, "parameters must be leaf tensor"
self._state[param] = dict()
self._state[param]["square_avg"] = flow.zeros_like(param)
if param_group["centered"]:
self._state[param]["grad_avg"] = flow.zeros_like(param)
self._centered_rmsprop = (
flow.builtin_op("rmsprop_update")
.Input("model")
......@@ -199,8 +197,14 @@ class RMSprop(Optimizer):
for param in param_group.parameters:
if param.grad is None:
continue
if "square_avg" not in self._state[param]:
self._state[param]["square_avg"] = flow.zeros_like(param)
ms_tensor = self._state[param]["square_avg"]
if param_group["centered"]:
if "grad_avg" not in self._state[param]:
self._state[param]["grad_avg"] = flow.zeros_like(param)
mg_tensor = self._state[param]["grad_avg"]
self._centered_rmsprop(
param, param.grad, ms_tensor, mg_tensor, **kwargs
......
......@@ -115,8 +115,7 @@ class SGD(Optimizer):
for param in param_group.parameters:
assert param.is_leaf, "parameters must be leaf tensor"
self._state[param] = dict()
if param_group["momentum"] != 0.0:
self._state[param]["momentum_buf"] = flow.zeros_like(param)
self._momentum_sgd = (
flow.builtin_op("momentum_update")
.Input("model")
......@@ -149,6 +148,8 @@ class SGD(Optimizer):
if param_group["momentum"] == 0.0:
self._sgd(param, param.grad, learning_rate_val=lr, l2=l2)
else:
if "momentum_buf" not in self._state[param]:
self._state[param]["momentum_buf"] = flow.zeros_like(param)
momentum_buf = self._state[param]["momentum_buf"]
beta = param_group["momentum"]
self._momentum_sgd(
......
......@@ -20,6 +20,7 @@ import unittest
import oneflow as flow
import oneflow.nn as nn
import oneflow.unittest
from data_utils import load_data_fashion_mnist
......@@ -71,6 +72,8 @@ def test_train_and_eval(test_case):
else:
device = flow.device("cuda")
net = LeNet()
lr, num_epochs = 0.02, 1
optimizer = flow.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
net.to(device)
batch_size = 256
......@@ -90,8 +93,6 @@ def test_train_and_eval(test_case):
loss = nn.CrossEntropyLoss()
loss.to(device)
lr, num_epochs = 0.02, 1
optimizer = flow.optim.SGD(net.parameters(), lr=lr, momentum=0.9)
final_accuracy = 0
for epoch in range(num_epochs):
......
......@@ -15,8 +15,8 @@ limitations under the License.
"""
import os
import unittest
import numpy as np
import numpy as np
import oneflow as flow
import oneflow.unittest
......@@ -170,7 +170,7 @@ class TestScalarGraph(oneflow.unittest.TestCase):
_test_scalar_consistent_train_graph(test_case, flow.placement("cuda", {0: [0]}))
def test_scalar_consistent_train_graph_cpu(test_case):
_test_scalar_consistent_train_graph(test_case, flow.placement("cuda", {0: [0]}))
_test_scalar_consistent_train_graph(test_case, flow.placement("cpu", {0: [0]}))
if __name__ == "__main__":
......
......@@ -18,11 +18,11 @@ import unittest
from collections import OrderedDict
import numpy as np
from automated_test_util import *
import oneflow as flow
import oneflow.unittest
from automated_test_util import *
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
class TestTensor(flow.unittest.TestCase):
......@@ -61,6 +61,21 @@ class TestTensor(flow.unittest.TestCase):
y = flow.Tensor(x, device="cuda")
test_case.assertTrue(y.is_local)
@flow.unittest.skip_unless_1n1d()
def test_consistent_set_data(test_case):
x_placement = flow.placement("cpu", {0: 0})
x_sbp = flow.sbp.broadcast
x = flow.ones(2, 3, placement=x_placement, sbp=x_sbp)
y_placement = flow.placement("cuda", {0: 0})
y_sbp = flow.sbp.split(0)
y = flow.ones(4, 5, placement=y_placement, sbp=y_sbp)
old_id = id(x)
x.data = y
test_case.assertEqual(old_id, id(x))
test_case.assertTrue(x.shape == (4, 5))
test_case.assertTrue(x.placement == y_placement)
test_case.assertTrue(x.sbp[0] == y_sbp)
@flow.unittest.skip_unless_1n1d()
def test_consistent_tensor_autograd_related_methods(test_case):
placement = flow.placement("cuda", {0: 0})
......
......@@ -18,11 +18,11 @@ import unittest
from collections import OrderedDict
import numpy as np
from automated_test_util import *
import oneflow as flow
import oneflow.unittest
from automated_test_util import *
@flow.unittest.skip_unless_1n1d()
class TestParameter(flow.unittest.TestCase):
......@@ -33,6 +33,16 @@ class TestParameter(flow.unittest.TestCase):
z = torch.nn.Parameter(y)
return z.grad_fn
def test_parameter_set_data(test_case):
a = flow.nn.Parameter(flow.ones(2, 3), False)
old_id = id(a)
b = flow.nn.Parameter(flow.ones(4, 5), True)
a.data = b
test_case.assertEqual(old_id, id(a))
test_case.assertTrue(a.shape == (4, 5))
test_case.assertTrue(a.requires_grad)
test_case.assertTrue(a.is_leaf)
if __name__ == "__main__":
unittest.main()
......@@ -18,11 +18,11 @@ import unittest
from collections import OrderedDict
import numpy as np
from automated_test_util import *
import oneflow as flow
import oneflow.unittest
from automated_test_util import *
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
class TestTensor(flow.unittest.TestCase):
......@@ -168,6 +168,18 @@ class TestTensor(flow.unittest.TestCase):
x = flow.Tensor(*shape, device=flow.device("cpu"))
test_case.assertTrue(not x.is_cuda)
@flow.unittest.skip_unless_1n1d()
def test_tensor_set_data(test_case):
a = flow.ones(2, 3, requires_grad=False)
b = flow.ones(4, 5, requires_grad=True).to("cuda")
old_id = id(a)
a.data = b
test_case.assertEqual(old_id, id(a))
test_case.assertTrue(a.shape == (4, 5))
test_case.assertTrue(a.device == flow.device("cuda"))
test_case.assertTrue(a.requires_grad)
test_case.assertFalse(a.is_leaf)
@flow.unittest.skip_unless_1n1d()
def test_tensor_unsupported_property(test_case):
......