提交 29697c2e 编写于 作者: M minqiyang

Add stop_gradient to VarBase to support loss function

test=develop
上级 fba3712a
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
syntax = "proto2"; syntax = "proto2";
option optimize_for = LITE_RUNTIME; /* option optimize_for = LITE_RUNTIME; */
package paddle.framework.proto; package paddle.framework.proto;
// Any incompatible changes to ProgramDesc and its dependencies should // Any incompatible changes to ProgramDesc and its dependencies should
......
...@@ -115,6 +115,7 @@ framework::Variable* CreateVariable(const std::string& name, ...@@ -115,6 +115,7 @@ framework::Variable* CreateVariable(const std::string& name,
varname = string::Sprintf("%s@%d", varname, id); varname = string::Sprintf("%s@%d", varname, id);
} }
LOG(ERROR) << "creating var " << varname;
VLOG(3) << "creating var " << varname; VLOG(3) << "creating var " << varname;
framework::Variable* var = scope->Var(varname); framework::Variable* var = scope->Var(varname);
framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>(); framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
...@@ -130,13 +131,22 @@ framework::LoDTensor& VarBase::Grad() { ...@@ -130,13 +131,22 @@ framework::LoDTensor& VarBase::Grad() {
} }
void VarBase::ApplyGrad(framework::Scope* scope, Variable* grad) { void VarBase::ApplyGrad(framework::Scope* scope, Variable* grad) {
PADDLE_ENFORCE(grad->IsInitialized(), "grad %s must be initialized",
var_desc_->Name());
PADDLE_ENFORCE(grad->Get<framework::LoDTensor>().IsInitialized(),
"variable %s has NO gradient, please set stop_gradient to it",
var_desc_->Name());
VLOG(3) << "apply var grad " << var_desc_->Name() << " " VLOG(3) << "apply var grad " << var_desc_->Name() << " "
<< grad->Get<framework::LoDTensor>().data<float>()[0]; << grad->Get<framework::LoDTensor>().data<float>()[0];
if (!grads_) { if (!grads_) {
grads_ = grads_ =
CreateVariable(string::Sprintf("%s@IGrad", var_desc_->Name()), CreateVariable(string::Sprintf("%s@IGrad", var_desc_->Name()),
var_->Get<framework::LoDTensor>().dims(), 0.0, scope); var_->Get<framework::LoDTensor>().dims(), 0.0, scope);
} }
AddTo(grad, grads_); AddTo(grad, grads_);
VLOG(3) << "grad_ after apply var grad " << var_desc_->Name() << " " VLOG(3) << "grad_ after apply var grad " << var_desc_->Name() << " "
<< grads_->Get<framework::LoDTensor>().data<float>()[0]; << grads_->Get<framework::LoDTensor>().data<float>()[0];
...@@ -153,7 +163,8 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) { ...@@ -153,7 +163,8 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
// grad op inputs can be forward inputs, so not in grad_to_var. // grad op inputs can be forward inputs, so not in grad_to_var.
continue; continue;
} }
VLOG(3) << "op grad in var " << grad_invar; VLOG(3) << "op grad input var " << grad_invar;
framework::VarDesc& grad_invar_desc =
block_->FindRecursiveOrCreateVar(grad_invar); block_->FindRecursiveOrCreateVar(grad_invar);
framework::Variable* var = scope->Var(grad_invar); framework::Variable* var = scope->Var(grad_invar);
const std::string& invar = grad_to_var_->at(grad_invar); const std::string& invar = grad_to_var_->at(grad_invar);
...@@ -165,21 +176,33 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) { ...@@ -165,21 +176,33 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
break; break;
} }
} }
grad_invar_desc.SetShape(
framework::vectorize(var->Get<framework::LoDTensor>().dims()));
VLOG(3)
<< "set op grad var desc's shape size "
<< framework::vectorize(var->Get<framework::LoDTensor>().dims()).size();
} }
LOG(ERROR) << "grad_op_desc_" << grad_op_desc_->Proto()->DebugString();
for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) { for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
VLOG(3) << "grad outvar " << outvar; VLOG(3) << "op grad output var " << outvar;
block_->FindRecursiveOrCreateVar(outvar); block_->FindRecursiveOrCreateVar(outvar);
framework::Variable* var = scope->Var(outvar); framework::Variable* var = scope->Var(outvar);
if (!var->IsInitialized()) { if (!var->IsInitialized()) {
VLOG(3) << "init op grad output var " << outvar;
framework::VarDesc* var_desc = block_->FindVar(outvar); framework::VarDesc* var_desc = block_->FindVar(outvar);
if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) { if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
var->GetMutable<framework::LoDTensor>(); var->GetMutable<framework::LoDTensor>();
// framework::Tensor* tensor = var->GetMutable<framework::LoDTensor>();
// tensor->mutable_data(platform::CPUPlace());
} else { } else {
LOG(ERROR) << "tracer doesn't support yet"; LOG(ERROR) << "tracer doesn't support yet";
} }
} }
VLOG(3) << "op grad output var " << outvar << " is inited";
} }
grad_op_desc_->InferShape(*block_); grad_op_desc_->InferShape(*block_);
grad_op_desc_->InferVarType(block_); grad_op_desc_->InferVarType(block_);
std::unique_ptr<framework::OperatorBase> opbase = std::unique_ptr<framework::OperatorBase> opbase =
...@@ -194,11 +217,15 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) { ...@@ -194,11 +217,15 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
VarBase* origin_var = (*input_vars_)[i]; VarBase* origin_var = (*input_vars_)[i];
for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) { for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
Variable* var = scope->FindVar(outvar); Variable* var = scope->FindVar(outvar);
std::string orig_var = grad_to_var_->at(outvar); if (var->IsInitialized()) {
if (origin_var->var_desc_->Name() != orig_var) { VLOG(3) << "get grad op output var " << outvar;
}
std::string orig_var_name = grad_to_var_->at(outvar);
if (origin_var->var_desc_->Name() != orig_var_name ||
origin_var->stop_gradient_) {
continue; continue;
} }
VLOG(3) << "apply grad " << outvar << " with origin " << orig_var; VLOG(3) << "apply grad " << outvar << " with origin " << orig_var_name;
origin_var->ApplyGrad(scope, var); origin_var->ApplyGrad(scope, var);
found = true; found = true;
ret.push_back(var); ret.push_back(var);
......
...@@ -29,12 +29,13 @@ class OpBase; ...@@ -29,12 +29,13 @@ class OpBase;
class VarBase { class VarBase {
public: public:
VarBase() explicit VarBase(bool stop_gradient = false)
: pre_op_(nullptr), : pre_op_(nullptr),
pre_op_out_idx_(-1), pre_op_out_idx_(-1),
var_desc_(nullptr), var_desc_(nullptr),
var_(nullptr), var_(nullptr),
grads_(nullptr) {} grads_(nullptr),
stop_gradient_(stop_gradient) {}
virtual ~VarBase() {} virtual ~VarBase() {}
...@@ -50,6 +51,8 @@ class VarBase { ...@@ -50,6 +51,8 @@ class VarBase {
framework::VarDesc* var_desc_; framework::VarDesc* var_desc_;
framework::Variable* var_; framework::Variable* var_;
framework::Variable* grads_; framework::Variable* grads_;
bool stop_gradient_;
}; };
class OpBase { class OpBase {
......
...@@ -110,6 +110,8 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> { ...@@ -110,6 +110,8 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* label = ctx.Input<Tensor>("Label"); auto* label = ctx.Input<Tensor>("Label");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
LOG(ERROR) << "CROSS ENTROPY GRAD DX: "
<< ctx.op().Output(framework::GradVarName("X"));
T* dx_data = dx->mutable_data<T>(ctx.GetPlace()); T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
// Following computation only depends on the last dimension size. So it's // Following computation only depends on the last dimension size. So it's
......
...@@ -111,7 +111,8 @@ PYBIND11_MODULE(core, m) { ...@@ -111,7 +111,8 @@ PYBIND11_MODULE(core, m) {
BindException(&m); BindException(&m);
py::class_<imperative::VarBase, PyVarBase>(m, "VarBase", R"DOC()DOC") py::class_<imperative::VarBase, PyVarBase>(m, "VarBase", R"DOC()DOC")
.def(py::init<>()) // .def(py::init<>())
.def(py::init<bool>(), py::arg("stop_gradient") = false)
.def("_run_backward", .def("_run_backward",
[](imperative::VarBase &self, framework::Scope *scope) { [](imperative::VarBase &self, framework::Scope *scope) {
self.RunBackward(scope); self.RunBackward(scope);
...@@ -129,7 +130,13 @@ PYBIND11_MODULE(core, m) { ...@@ -129,7 +130,13 @@ PYBIND11_MODULE(core, m) {
[](imperative::VarBase &self, framework::VarDesc *var_desc) { [](imperative::VarBase &self, framework::VarDesc *var_desc) {
self.var_desc_ = var_desc; self.var_desc_ = var_desc;
}, },
py::return_value_policy::reference); py::return_value_policy::reference)
.def_property(
"stop_gradient",
[](const imperative::VarBase &self) { return self.stop_gradient_; },
[](imperative::VarBase &self, bool stop_gradient) {
self.stop_gradient_ = stop_gradient;
});
py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC") py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
.def(py::init<>()) .def(py::init<>())
......
...@@ -354,11 +354,11 @@ class Variable(object): ...@@ -354,11 +354,11 @@ class Variable(object):
self.block.vars[name] = self self.block.vars[name] = self
self.op = None self.op = None
self.stop_gradient = stop_gradient
self.is_data = is_data self.is_data = is_data
if _in_imperative_mode(): if _in_imperative_mode():
self._ivar = core.VarBase() self._ivar = core.VarBase()
self._ivar.desc = self.desc self._ivar.desc = self.desc
self._ivar.stop_gradient = stop_gradient
def _numpy(self): def _numpy(self):
scope = _imperative_tracer().get_scope() scope = _imperative_tracer().get_scope()
...@@ -366,7 +366,7 @@ class Variable(object): ...@@ -366,7 +366,7 @@ class Variable(object):
return np.array(tensor) return np.array(tensor)
def _backward(self): def _backward(self):
scope = _imperative_tracer().get_scope(self.block.desc) scope = _imperative_tracer().get_scope()
self._ivar._run_backward(scope) self._ivar._run_backward(scope)
def _gradient(self): def _gradient(self):
...@@ -415,6 +415,14 @@ class Variable(object): ...@@ -415,6 +415,14 @@ class Variable(object):
""" """
self.desc = input self.desc = input
@property
def _stop_gradient(self):
return self._ivar.stop_gradient
@_stop_gradient.setter
def _stop_gradient(self, s):
self._ivar.stop_gradient = s
@property @property
def persistable(self): def persistable(self):
return self.desc.persistable() return self.desc.persistable()
......
...@@ -25,12 +25,22 @@ __all__ = ['PyLayer'] ...@@ -25,12 +25,22 @@ __all__ = ['PyLayer']
class PyLayer(core.Layer): class PyLayer(core.Layer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._once_built = True
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
self._helper = LayerHelper(type(self).__name__, **kwargs) self._helper = LayerHelper(type(self).__name__, **kwargs)
self._dtype = kwargs.get("dtype", core.VarDesc.VarType.FP32) self._dtype = kwargs.get("dtype", core.VarDesc.VarType.FP32)
def _build_once(self, inputs):
pass
def __call__(self, *inputs): def __call__(self, *inputs):
if self._once_built:
self._build_once(*inputs)
self._once_built = False
outputs = self.forward(*inputs) outputs = self.forward(*inputs)
return outputs return outputs
def forward(self, *inputs): def forward(self, *inputs):
......
...@@ -18,14 +18,15 @@ import numpy as np ...@@ -18,14 +18,15 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.imperative.nn import Conv2D, Pool2D from paddle.fluid.imperative.nn import Conv2D, Pool2D, FC
from paddle.fluid.imperative.base import to_variable
class SimpleImgConvPool(fluid.imperative.PyLayer): class SimpleImgConvPool(fluid.imperative.PyLayer):
def __init__(self, def __init__(self,
num_channels, num_channels,
num_filters,
filter_size, filter_size,
num_filters,
pool_size, pool_size,
pool_stride, pool_stride,
pool_padding=0, pool_padding=0,
...@@ -81,24 +82,24 @@ class MNIST(fluid.imperative.PyLayer): ...@@ -81,24 +82,24 @@ class MNIST(fluid.imperative.PyLayer):
super(MNIST, self).__init__(param_attr=param_attr, bias_attr=bias_attr) super(MNIST, self).__init__(param_attr=param_attr, bias_attr=bias_attr)
self._simple_img_conv_pool_1 = SimpleImgConvPool( self._simple_img_conv_pool_1 = SimpleImgConvPool(
num_channels=3, 1, 5, 20, 2, 2, act="relu")
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool( self._simple_img_conv_pool_2 = SimpleImgConvPool(
num_channels=3, 20, 5, 50, 2, 2, act="relu")
filter_size=5,
num_filters=50, pool_2_shape = 50 * 8 * 8
pool_size=2, SIZE = 10
pool_stride=2, scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
act="relu") self._fc = FC(-1,
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)))
def forward(self, inputs): def forward(self, inputs):
x = self._simple_img_conv_pool_1(inputs) x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x) x = self._simple_img_conv_pool_2(x)
x = self._fc(x)
return x return x
...@@ -107,8 +108,20 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -107,8 +108,20 @@ class TestImperativeMnist(unittest.TestCase):
with fluid.imperative.guard(): with fluid.imperative.guard():
mnist = MNIST() mnist = MNIST()
data = np.random.rand(2, 3, 5, 5).astype('float32') x_data = np.random.rand(128, 1, 28, 28).astype('float32')
mnist(data) img = to_variable(x_data)
y_data = np.random.rand(128, 1).astype('int64')
label = to_variable(y_data)
label._stop_gradient = True
predict = mnist(img)
print(predict.shape, predict.dtype, label.shape, label.dtype)
out = fluid.layers.cross_entropy(predict, label)
print(out.shape, out.dtype)
out._backward()
filter_grad = mnist._simple_img_conv_pool_1._conv2d._filter_param._gradient(
)
print(filter_grad)
# np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) # np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
# with fluid.imperative.guard(): # with fluid.imperative.guard():
# mlp = MLP() # mlp = MLP()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册