未验证 提交 23761bea 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge pull request #14971 from velconia/imperative_mnist

Imperative Optimizer
...@@ -69,6 +69,15 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -69,6 +69,15 @@ inline std::string GradVarName(const std::string& var_name) {
return result; return result;
} }
inline std::string GradOriginalVarName(const std::string& grad_var_name) {
std::size_t pos = grad_var_name.rfind(kGradVarSuffix);
if (pos == std::string::npos) {
return grad_var_name;
} else {
return grad_var_name.substr(0, pos);
}
}
proto::VarType::Type GetDataTypeOfVar(const Variable* var); proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var); const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var); Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
......
...@@ -288,3 +288,30 @@ TEST(OpKernel, multi_inputs) { ...@@ -288,3 +288,30 @@ TEST(OpKernel, multi_inputs) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_place); op->Run(scope, cpu_place);
} }
TEST(VarNameTest, all) {
std::string var_name("X");
std::string grad_var_name = paddle::framework::GradVarName(var_name);
ASSERT_EQ(grad_var_name, "X@GRAD");
std::string original_var_name =
paddle::framework::GradOriginalVarName(grad_var_name);
ASSERT_EQ(original_var_name, "X");
original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
ASSERT_EQ(original_var_name, "X");
std::string var_name_2("XYZ");
grad_var_name = paddle::framework::GradVarName(var_name_2);
ASSERT_EQ(grad_var_name, "XYZ@GRAD");
original_var_name = paddle::framework::GradOriginalVarName(grad_var_name);
ASSERT_EQ(original_var_name, "XYZ");
original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
ASSERT_EQ(original_var_name, "XYZ");
std::string var_name_3("");
grad_var_name = paddle::framework::GradVarName(var_name_3);
ASSERT_EQ(grad_var_name, "@GRAD");
original_var_name = paddle::framework::GradOriginalVarName(grad_var_name);
ASSERT_EQ(original_var_name, "");
original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
ASSERT_EQ(original_var_name, "");
}
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
namespace paddle { namespace paddle {
...@@ -31,8 +32,14 @@ using framework::Variable; ...@@ -31,8 +32,14 @@ using framework::Variable;
void AddTo(Variable* src, Variable* dst) { void AddTo(Variable* src, Variable* dst) {
framework::LoDTensor* dst_tensor = dst->GetMutable<framework::LoDTensor>(); framework::LoDTensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
framework::LoDTensor* src_tensor = src->GetMutable<framework::LoDTensor>(); framework::LoDTensor* src_tensor = src->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(), "%lld vs %lld", // FIXME(minqiyang): loss_grad op will pass a zero grad of label
dst_tensor->numel(), src_tensor->numel()); // ugly fix for it
if (src_tensor->numel() == 0) {
return;
}
PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
"dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
src_tensor->numel());
float* dst_data = dst_tensor->mutable_data<float>(platform::CPUPlace()); float* dst_data = dst_tensor->mutable_data<float>(platform::CPUPlace());
const float* src_data = src_tensor->data<float>(); const float* src_data = src_tensor->data<float>();
for (size_t i = 0; i < src_tensor->numel(); ++i) { for (size_t i = 0; i < src_tensor->numel(); ++i) {
...@@ -45,6 +52,10 @@ class Autograd { ...@@ -45,6 +52,10 @@ class Autograd {
Autograd() {} Autograd() {}
void RunBackward(VarBase* var) { void RunBackward(VarBase* var) {
if (var->stop_gradient_) {
return;
}
std::deque<OpBase*> ready; std::deque<OpBase*> ready;
ready.push_back(var->pre_op_); ready.push_back(var->pre_op_);
...@@ -60,6 +71,9 @@ class Autograd { ...@@ -60,6 +71,9 @@ class Autograd {
const std::vector<VarBase*>& ingrads = it.second; const std::vector<VarBase*>& ingrads = it.second;
for (size_t i = 0; i < ingrads.size(); ++i) { for (size_t i = 0; i < ingrads.size(); ++i) {
if (!ingrads[i]) continue; if (!ingrads[i]) continue;
if (ready_op->input_vars_[it.first][i]->stop_gradient_) {
continue;
}
OpBase* pre_op = ready_op->pre_ops_[it.first][i]; OpBase* pre_op = ready_op->pre_ops_[it.first][i];
if (!pre_op) continue; if (!pre_op) continue;
...@@ -107,7 +121,7 @@ framework::LoDTensor& VarBase::Grad() { ...@@ -107,7 +121,7 @@ framework::LoDTensor& VarBase::Grad() {
std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
if (!grad_op_desc_) { if (!grad_op_desc_) {
VLOG(3) << "op with no grad: " << op_desc_->Type(); LOG(WARNING) << "op with no grad: " << op_desc_->Type();
return {}; return {};
} }
VLOG(3) << "op grad " << grad_op_desc_->Type(); VLOG(3) << "op grad " << grad_op_desc_->Type();
...@@ -117,15 +131,18 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { ...@@ -117,15 +131,18 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
for (auto it : grad_output_vars_) { for (auto it : grad_output_vars_) {
auto& outputs = grad_outputs[it.first]; auto& outputs = grad_outputs[it.first];
for (size_t i = 0; i < it.second.size(); ++i) { for (size_t i = 0; i < it.second.size(); ++i) {
tmp_vars.emplace_back(new framework::Variable()); // Allocate a new variable
outputs.push_back(tmp_vars.back().get()); Variable* tmp_var = new framework::Variable();
outputs.back()->GetMutable<framework::LoDTensor>(); tmp_var->GetMutable<framework::LoDTensor>();
tmp_vars.emplace_back(tmp_var);
outputs.push_back(tmp_var);
} }
} }
framework::RuntimeContext ctx(grad_input_vars_, grad_outputs); framework::RuntimeContext ctx(grad_input_vars_, grad_outputs);
// No need to do static infer shape here. // No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_); // grad_op_desc_->InferShape(*block_);
grad_op_desc_->InferVarType(block_); grad_op_desc_->InferVarType(block_);
...@@ -144,6 +161,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { ...@@ -144,6 +161,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
for (auto it : grad_output_vars_) { for (auto it : grad_output_vars_) {
auto& outputs = grad_outputs[it.first]; auto& outputs = grad_outputs[it.first];
auto& origin_outputs = it.second; auto& origin_outputs = it.second;
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
framework::Variable* orig_grad = origin_outputs[i]; framework::Variable* orig_grad = origin_outputs[i];
AddTo(outputs[i], orig_grad); AddTo(outputs[i], orig_grad);
......
...@@ -86,23 +86,30 @@ class VarBase { ...@@ -86,23 +86,30 @@ class VarBase {
pre_op_out_idx_(-1), pre_op_out_idx_(-1),
var_desc_(nullptr), var_desc_(nullptr),
var_(new framework::Variable()), var_(new framework::Variable()),
grads_(new framework::Variable()) {} grads_(new framework::Variable()),
stop_gradient_(false) {}
virtual ~VarBase() { explicit VarBase(bool stop_gradient)
if (var_) { : pre_op_(nullptr),
delete var_; pre_op_out_idx_(-1),
var_ = nullptr; var_desc_(nullptr),
} var_(new framework::Variable()),
if (grads_) { grads_(new framework::Variable()),
delete grads_; stop_gradient_(stop_gradient) {}
grads_ = nullptr;
} virtual ~VarBase() {}
}
void RunBackward(); void RunBackward();
framework::LoDTensor& Grad(); framework::LoDTensor& Grad();
inline std::string GradName() const {
PADDLE_ENFORCE(
var_desc_,
"Couldn't get gradient variable's name, please call backward() first");
return string::Sprintf("%s@IGrad", var_desc_->Name());
}
OpBase* pre_op_; OpBase* pre_op_;
std::string pre_op_out_name_; std::string pre_op_out_name_;
int pre_op_out_idx_; int pre_op_out_idx_;
...@@ -110,6 +117,8 @@ class VarBase { ...@@ -110,6 +117,8 @@ class VarBase {
framework::VarDesc* var_desc_; framework::VarDesc* var_desc_;
framework::Variable* var_; framework::Variable* var_;
framework::Variable* grads_; framework::Variable* grads_;
bool stop_gradient_;
}; };
class OpBase { class OpBase {
......
...@@ -50,16 +50,14 @@ void InitVar(framework::Variable* var, framework::Variable* grad_var) { ...@@ -50,16 +50,14 @@ void InitVar(framework::Variable* var, framework::Variable* grad_var) {
class Tracer { class Tracer {
public: public:
explicit Tracer(framework::BlockDesc* root_block, explicit Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {}
framework::BlockDesc* startup_block)
: root_block_(root_block), startup_block_(startup_block) {}
virtual ~Tracer() {} virtual ~Tracer() {}
void Trace(OpBase* op, void Trace(OpBase* op,
const std::map<std::string, std::vector<VarBase*>>& inputs, const std::map<std::string, std::vector<VarBase*>>& inputs,
const std::map<std::string, std::vector<VarBase*>>& outputs, const std::map<std::string, std::vector<VarBase*>>& outputs,
framework::BlockDesc* block) { framework::BlockDesc* block, const bool stop_gradient = false) {
std::map<std::string, VarBase*> vars; std::map<std::string, VarBase*> vars;
framework::OpDesc* op_desc = op->op_desc_; framework::OpDesc* op_desc = op->op_desc_;
...@@ -107,6 +105,7 @@ class Tracer { ...@@ -107,6 +105,7 @@ class Tracer {
} else { } else {
LOG(ERROR) << "tracer doesn't support yet"; LOG(ERROR) << "tracer doesn't support yet";
} }
out->stop_gradient_ = stop_gradient;
out->pre_op_ = op; out->pre_op_ = op;
out->pre_op_out_name_ = it.first; out->pre_op_out_name_ = it.first;
out->pre_op_out_idx_ = i; out->pre_op_out_idx_ = i;
...@@ -130,9 +129,7 @@ class Tracer { ...@@ -130,9 +129,7 @@ class Tracer {
p.op.RuntimeInferShape(scope, place, ctx); p.op.RuntimeInferShape(scope, place, ctx);
p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx)); p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
if (block == startup_block_) { if (!stop_gradient) {
op->grad_op_desc_ = nullptr;
} else {
framework::OpDesc* grad_op_desc; framework::OpDesc* grad_op_desc;
auto grad_to_var = new std::unordered_map<std::string, std::string>(); auto grad_to_var = new std::unordered_map<std::string, std::string>();
CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var); CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var);
...@@ -156,6 +153,7 @@ class Tracer { ...@@ -156,6 +153,7 @@ class Tracer {
} }
} }
} }
for (auto it : grad_op_desc->Outputs()) { for (auto it : grad_op_desc->Outputs()) {
auto& grad_out_vars = op->grad_output_vars_[it.first]; auto& grad_out_vars = op->grad_output_vars_[it.first];
for (const std::string& grad_outvar : it.second) { for (const std::string& grad_outvar : it.second) {
...@@ -170,12 +168,12 @@ class Tracer { ...@@ -170,12 +168,12 @@ class Tracer {
} }
} }
} }
op->block_ = block; op->block_ = block;
} }
private: private:
framework::BlockDesc* root_block_; framework::BlockDesc* root_block_;
framework::BlockDesc* startup_block_;
}; };
} // namespace imperative } // namespace imperative
......
...@@ -23,9 +23,8 @@ namespace pybind { ...@@ -23,9 +23,8 @@ namespace pybind {
void BindTracer(pybind11::module *m) { void BindTracer(pybind11::module *m) {
pybind11::class_<imperative::Tracer>(*m, "Tracer", "") pybind11::class_<imperative::Tracer>(*m, "Tracer", "")
.def("__init__", .def("__init__",
[](imperative::Tracer &self, framework::BlockDesc *root_block, [](imperative::Tracer &self, framework::BlockDesc *root_block) {
framework::BlockDesc *startup_block) { new (&self) imperative::Tracer(root_block);
new (&self) imperative::Tracer(root_block, startup_block);
}) })
.def("trace", &imperative::Tracer::Trace); .def("trace", &imperative::Tracer::Trace);
} }
......
...@@ -125,11 +125,26 @@ PYBIND11_MODULE(core, m) { ...@@ -125,11 +125,26 @@ PYBIND11_MODULE(core, m) {
m.add_object("_cleanup", m.add_object("_cleanup",
py::capsule([]() { ScopePool::Instance().Clear(); })); py::capsule([]() { ScopePool::Instance().Clear(); }));
py::class_<imperative::VarBase, PyVarBase>(m, "VarBase", R"DOC()DOC") py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
.def(py::init<>()) m, "VarBase", R"DOC()DOC")
// .def(py::init<>())
.def(py::init<bool>(), py::arg("stop_gradient") = false)
.def("_run_backward", .def("_run_backward",
[](imperative::VarBase &self) { self.RunBackward(); }) [](imperative::VarBase &self) { self.RunBackward(); })
.def("_grad_name", &imperative::VarBase::GradName)
.def("_grad", &imperative::VarBase::Grad) .def("_grad", &imperative::VarBase::Grad)
.def_property("grad_value",
[](const imperative::VarBase &self) { return self.grads_; },
[](imperative::VarBase &self, framework::Variable *grad) {
self.grads_ = grad;
},
py::return_value_policy::reference)
.def_property("value",
[](const imperative::VarBase &self) { return self.var_; },
[](imperative::VarBase &self, framework::Variable *var) {
self.var_ = var;
},
py::return_value_policy::reference)
.def_property( .def_property(
"desc", "desc",
[](const imperative::VarBase &self) { return self.var_desc_; }, [](const imperative::VarBase &self) { return self.var_desc_; },
...@@ -137,12 +152,12 @@ PYBIND11_MODULE(core, m) { ...@@ -137,12 +152,12 @@ PYBIND11_MODULE(core, m) {
self.var_desc_ = var_desc; self.var_desc_ = var_desc;
}, },
py::return_value_policy::reference) py::return_value_policy::reference)
.def_property("var", .def_property(
[](const imperative::VarBase &self) { return self.var_; }, "stop_gradient",
[](imperative::VarBase &self, framework::Variable *var) { [](const imperative::VarBase &self) { return self.stop_gradient_; },
self.var_ = var; [](imperative::VarBase &self, bool stop_gradient) {
}, self.stop_gradient_ = stop_gradient;
py::return_value_policy::reference); });
py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC") py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
.def(py::init<>()) .def(py::init<>())
......
...@@ -20,7 +20,6 @@ import contextlib ...@@ -20,7 +20,6 @@ import contextlib
import os import os
import re import re
import six import six
import sys
import numpy as np import numpy as np
...@@ -368,9 +367,10 @@ class Variable(object): ...@@ -368,9 +367,10 @@ class Variable(object):
if _in_imperative_mode(): if _in_imperative_mode():
self._ivar = core.VarBase() self._ivar = core.VarBase()
self._ivar.desc = self.desc self._ivar.desc = self.desc
self._ivar.stop_gradient = stop_gradient
def _numpy(self): def _numpy(self):
tensor = self._ivar.var.get_tensor() tensor = self._ivar.value.get_tensor()
return np.array(tensor) return np.array(tensor)
def _backward(self): def _backward(self):
...@@ -379,6 +379,14 @@ class Variable(object): ...@@ -379,6 +379,14 @@ class Variable(object):
def _gradient(self): def _gradient(self):
return np.array(self._ivar._grad()) return np.array(self._ivar._grad())
@property
def _value(self):
return self._ivar.value
@_value.setter
def _value(self, v):
self._ivar.value = v
def __str__(self): def __str__(self):
return self.to_string(True) return self.to_string(True)
...@@ -422,6 +430,14 @@ class Variable(object): ...@@ -422,6 +430,14 @@ class Variable(object):
""" """
self.desc = input self.desc = input
@property
def _stop_gradient(self):
return self._ivar.stop_gradient
@_stop_gradient.setter
def _stop_gradient(self, s):
self._ivar.stop_gradient = s
@property @property
def persistable(self): def persistable(self):
return self.desc.persistable() return self.desc.persistable()
...@@ -681,9 +697,11 @@ class Operator(object): ...@@ -681,9 +697,11 @@ class Operator(object):
self._update_desc_attr(attr_name, attr_val) self._update_desc_attr(attr_name, attr_val)
self.desc.check_attrs() self.desc.check_attrs()
if self._has_kernel(type): if self._has_kernel(type):
self.desc.infer_var_type(self.block.desc) self.desc.infer_var_type(self.block.desc)
self.desc.infer_shape(self.block.desc) self.desc.infer_shape(self.block.desc)
if _in_imperative_mode(): if _in_imperative_mode():
self.iop = core.OpBase() self.iop = core.OpBase()
self.iop.desc = self.desc self.iop.desc = self.desc
...@@ -1266,12 +1284,22 @@ class Block(object): ...@@ -1266,12 +1284,22 @@ class Block(object):
Operator: the append Operator. Operator: the append Operator.
""" """
op_desc = self.desc.append_op() op_desc = self.desc.append_op()
op = Operator(block=self, desc=op_desc, *args, **kwargs) op = Operator(
if _in_imperative_mode(): block=self,
_imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc) desc=op_desc,
type=kwargs.get("type", None),
inputs=kwargs.get("inputs", None),
outputs=kwargs.get("outputs", None),
attrs=kwargs.get("attrs", None))
self.ops.append(op) self.ops.append(op)
self._trace_op(op, kwargs.get("stop_gradient", False))
return op return op
def _trace_op(self, op, stop_gradient=False):
if _in_imperative_mode():
_imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc,
stop_gradient)
def _insert_op(self, index, *args, **kwargs): def _insert_op(self, index, *args, **kwargs):
""" """
Insert a Operator according to the giving arguments. Insert a Operator according to the giving arguments.
...@@ -1317,10 +1345,15 @@ class Block(object): ...@@ -1317,10 +1345,15 @@ class Block(object):
def _prepend_op(self, *args, **kwargs): def _prepend_op(self, *args, **kwargs):
op_desc = self.desc._prepend_op() op_desc = self.desc._prepend_op()
op = Operator(self, op_desc, *args, **kwargs) op = Operator(
if _in_imperative_mode(): self,
_imperative_tracer().trace(op.iop, op.inputs, op.outputs, self.desc) op_desc,
type=kwargs.get("type", None),
inputs=kwargs.get("inputs", None),
outputs=kwargs.get("outputs", None),
attrs=kwargs.get("attrs", None))
self.ops.insert(0, op) self.ops.insert(0, op)
self._trace_op(op, kwargs.get("stop_gradient", False))
return op return op
def _sync_with_cpp(self): def _sync_with_cpp(self):
......
...@@ -20,6 +20,10 @@ from .base import * ...@@ -20,6 +20,10 @@ from .base import *
from . import layers from . import layers
from .layers import * from .layers import *
from . import nn
from .nn import *
__all__ = [] __all__ = []
__all__ += layers.__all__ __all__ += layers.__all__
__all__ += base.__all__ __all__ += base.__all__
__all__ += nn.__all__
...@@ -28,8 +28,7 @@ def enabled(): ...@@ -28,8 +28,7 @@ def enabled():
def guard(): def guard():
train = framework.Program() train = framework.Program()
startup = framework.Program() startup = framework.Program()
tracer = core.Tracer(train.current_block().desc, tracer = core.Tracer(train.current_block().desc)
startup.current_block().desc)
with framework.program_guard(train, startup): with framework.program_guard(train, startup):
with framework.unique_name.guard(): with framework.unique_name.guard():
with framework._imperative_guard(tracer): with framework._imperative_guard(tracer):
...@@ -46,7 +45,7 @@ def to_variable(value, block=None): ...@@ -46,7 +45,7 @@ def to_variable(value, block=None):
name=None, name=None,
shape=value.shape, shape=value.shape,
dtype=value.dtype) dtype=value.dtype)
var = py_var._ivar.var var = py_var._ivar.value
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set(value, core.CPUPlace()) tensor.set(value, core.CPUPlace())
return py_var return py_var
......
...@@ -24,26 +24,21 @@ __all__ = ['PyLayer'] ...@@ -24,26 +24,21 @@ __all__ = ['PyLayer']
class PyLayer(core.Layer): class PyLayer(core.Layer):
def __init__(self): def __init__(self, dtype=core.VarDesc.VarType.FP32, name=None):
self._built = False self._once_built = False
self._dtype = dtype
def __call__(self, inputs):
if not isinstance(inputs, list) and not isinstance(inputs, tuple):
inputs = [inputs]
var_inputs = []
for x in inputs:
py_var = base.to_variable(x)
var_inputs.append(py_var)
if not self._built:
self._build_once(inputs)
self._built = True
outputs = self.forward(var_inputs)
return outputs
def _build_once(self, inputs): def _build_once(self, inputs):
pass pass
def forward(self, inputs): def __call__(self, *inputs):
return [] if not self._once_built:
self._build_once(*inputs)
self._once_built = True
outputs = self.forward(*inputs)
return outputs
def forward(self, *inputs):
raise NotImplementedError
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from six.moves import reduce
from .. import core
from ..layers import utils
from . import layers
from ..framework import Variable, OpProtoHolder
from ..param_attr import ParamAttr
from ..initializer import Normal, Constant
__all__ = [
'Conv2D',
'Pool2D',
'FC',
]
class Conv2D(layers.PyLayer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
padding=0,
dilation=1,
groups=None,
use_cudnn=True,
act=None,
param_attr=None,
bias_attr=None,
name=None,
dtype=core.VarDesc.VarType.FP32):
assert param_attr is not False, "param_attr should not be False here."
super(Conv2D, self).__init__(name=name, dtype=dtype)
from ..layer_helper import LayerHelper
self._helper = LayerHelper(
type(self).__name__,
param_attr=param_attr,
bias_attr=bias_attr,
dtype=dtype,
name=name)
self._groups = groups
self._stride = utils.convert_to_list(stride, 2, 'stride')
self._padding = utils.convert_to_list(padding, 2, 'padding')
self._dilation = utils.convert_to_list(dilation, 2, 'dilation')
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
self._use_cudnn = use_cudnn
self._num_channels = num_channels
if (self._num_channels == self._groups and
num_filters % self._num_channels == 0 and not self._use_cudnn):
self._l_type = 'depthwise_conv2d'
else:
self._l_type = 'conv2d'
if groups is None:
num_filter_channels = num_channels
else:
if num_channels % groups != 0:
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels // groups
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
filter_shape = [num_filters, int(num_filter_channels)] + filter_size
def _get_default_param_initializer():
filter_elem_num = filter_size[0] * filter_size[1] * num_channels
std = (2.0 / filter_elem_num)**0.5
return Normal(0.0, std, 0)
self._filter_param = self._helper.create_parameter(
attr=self._helper.param_attr,
shape=filter_shape,
dtype=self._dtype,
default_initializer=_get_default_param_initializer())
if self._use_cudnn:
self._helper.create_variable(
name="kCUDNNFwdAlgoCache",
persistable=True,
type=core.VarDesc.VarType.RAW)
self._helper.create_variable(
name="kCUDNNBwdDataAlgoCache",
persistable=True,
type=core.VarDesc.VarType.RAW)
self._helper.create_variable(
name="kCUDNNBwdFilterAlgoCache",
persistable=True,
type=core.VarDesc.VarType.RAW)
self._bias_param = self._helper.create_parameter(
attr=self._helper.bias_attr,
shape=[num_filters],
dtype=self._dtype,
is_bias=True)
def forward(self, input):
pre_bias = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type=self._l_type,
inputs={
'Input': input,
'Filter': self._filter_param,
},
outputs={"Output": pre_bias},
attrs={
'strides': self._stride,
'paddings': self._padding,
'dilations': self._dilation,
'groups': self._groups,
'use_cudnn': self._use_cudnn,
'use_mkldnn': False,
})
pre_act = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias],
'Y': [self._bias_param]},
outputs={'Out': [pre_act]},
attrs={'axis': 1})
return self._helper.append_activation(pre_act)
class Pool2D(layers.PyLayer):
def __init__(self,
pool_size=-1,
pool_type="max",
pool_stride=1,
pool_padding=0,
global_pooling=False,
use_cudnn=True,
ceil_mode=False,
exclusive=True,
name=None,
dtype=core.VarDesc.VarType.FP32):
if pool_type not in ["max", "avg"]:
raise ValueError(
"Unknown pool_type: '%s'. It can only be 'max' or 'avg'.",
str(pool_type))
if global_pooling is False and pool_size == -1:
raise ValueError(
"When the global_pooling is False, pool_size must be passed "
"and be a valid value. Received pool_size: " + str(pool_size))
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
super(Pool2D, self).__init__(name=name, dtype=dtype)
from ..layer_helper import LayerHelper
self._helper = LayerHelper(type(self).__name__, dtype=dtype, name=name)
self._pool_type = pool_type
self._pool_size = utils.convert_to_list(pool_size, 2, 'pool_size')
self._pool_padding = utils.convert_to_list(pool_padding, 2,
'pool_padding')
self._pool_stride = utils.convert_to_list(pool_stride, 2, 'pool_stride')
self._global_pooling = global_pooling
self._use_cudnn = use_cudnn
self._ceil_mode = ceil_mode
self._exclusive = exclusive
self._l_type = 'pool2d'
def forward(self, input):
pool_out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type=self._l_type,
inputs={"X": input},
outputs={"Out": pool_out},
attrs={
"pooling_type": self._pool_type,
"ksize": self._pool_size,
"global_pooling": self._global_pooling,
"strides": self._pool_stride,
"paddings": self._pool_padding,
"use_cudnn": self._use_cudnn,
"ceil_mode": self._ceil_mode,
"use_mkldnn": False,
"exclusive": self._exclusive,
})
return pool_out
class FC(layers.PyLayer):
def __init__(self,
size,
param_attr=None,
num_flatten_dims=1,
dtype=core.VarDesc.VarType.FP32):
super(FC, self).__init__()
self._size = size
self._num_flatten_dims = num_flatten_dims
self._dtype = dtype
from ..layer_helper import LayerHelper
self._helper = LayerHelper('FC', param_attr=param_attr)
def _build_once(self, input):
input_shape = input.shape
param_shape = [
reduce(lambda a, b: a * b, input_shape[self._num_flatten_dims:], 1)
] + [self._size]
self._w = self._helper.create_parameter(
attr=self._helper.param_attr,
shape=param_shape,
dtype=self._dtype,
is_bias=False)
def forward(self, input):
tmp = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="mul",
inputs={"X": input,
"Y": self._w},
outputs={"Out": tmp},
attrs={
"x_num_col_dims": self._num_flatten_dims,
"y_num_col_dims": 1
})
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="sum",
inputs={"X": [tmp]},
outputs={"Out": out},
attrs={"use_mkldnn": False})
return out
...@@ -162,7 +162,8 @@ class ConstantInitializer(Initializer): ...@@ -162,7 +162,8 @@ class ConstantInitializer(Initializer):
"dtype": int(var.dtype), "dtype": int(var.dtype),
"value": float(self._value), "value": float(self._value),
'force_cpu': self._force_cpu or force_init_on_cpu() 'force_cpu': self._force_cpu or force_init_on_cpu()
}) },
stop_gradient=True)
var.op = op var.op = op
return op return op
...@@ -231,7 +232,8 @@ class UniformInitializer(Initializer): ...@@ -231,7 +232,8 @@ class UniformInitializer(Initializer):
"min": self._low, "min": self._low,
"max": self._high, "max": self._high,
"seed": self._seed "seed": self._seed
}) },
stop_gradient=True)
if var.dtype == VarDesc.VarType.FP16: if var.dtype == VarDesc.VarType.FP16:
block.append_op( block.append_op(
...@@ -309,7 +311,8 @@ class NormalInitializer(Initializer): ...@@ -309,7 +311,8 @@ class NormalInitializer(Initializer):
"std": self._std_dev, "std": self._std_dev,
"seed": self._seed, "seed": self._seed,
"use_mkldnn": False "use_mkldnn": False
}) },
stop_gradient=True)
if var.dtype == VarDesc.VarType.FP16: if var.dtype == VarDesc.VarType.FP16:
block.append_op( block.append_op(
...@@ -371,7 +374,8 @@ class TruncatedNormalInitializer(Initializer): ...@@ -371,7 +374,8 @@ class TruncatedNormalInitializer(Initializer):
"mean": self._mean, "mean": self._mean,
"std": self._std_dev, "std": self._std_dev,
"seed": self._seed "seed": self._seed
}) },
stop_gradient=True)
var.op = op var.op = op
return op return op
...@@ -461,7 +465,8 @@ class XavierInitializer(Initializer): ...@@ -461,7 +465,8 @@ class XavierInitializer(Initializer):
"min": -limit, "min": -limit,
"max": limit, "max": limit,
"seed": self._seed "seed": self._seed
}) },
stop_gradient=True)
else: else:
std = np.sqrt(2.0 / float(fan_in + fan_out)) std = np.sqrt(2.0 / float(fan_in + fan_out))
...@@ -474,7 +479,8 @@ class XavierInitializer(Initializer): ...@@ -474,7 +479,8 @@ class XavierInitializer(Initializer):
"mean": 0.0, "mean": 0.0,
"std": std, "std": std,
"seed": self._seed "seed": self._seed
}) },
stop_gradient=True)
var.op = op var.op = op
return op return op
...@@ -559,7 +565,8 @@ class MSRAInitializer(Initializer): ...@@ -559,7 +565,8 @@ class MSRAInitializer(Initializer):
"min": -limit, "min": -limit,
"max": limit, "max": limit,
"seed": self._seed "seed": self._seed
}) },
stop_gradient=True)
else: else:
std = np.sqrt(2.0 / float(fan_in)) std = np.sqrt(2.0 / float(fan_in))
...@@ -572,7 +579,8 @@ class MSRAInitializer(Initializer): ...@@ -572,7 +579,8 @@ class MSRAInitializer(Initializer):
"mean": 0.0, "mean": 0.0,
"std": std, "std": std,
"seed": self._seed "seed": self._seed
}) },
stop_gradient=True)
var.op = op var.op = op
return op return op
......
...@@ -22,8 +22,8 @@ import numpy as np ...@@ -22,8 +22,8 @@ import numpy as np
from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating, _in_imperative_mode from .framework import Variable, Parameter, default_main_program, default_startup_program, dtype_is_floating, _in_imperative_mode
from . import unique_name from . import unique_name
from paddle.fluid.imperative import base as imperative_base
from paddle.fluid.initializer import Constant, Xavier from paddle.fluid.initializer import Constant, Xavier
from paddle.fluid.imperative import base
from .param_attr import ParamAttr, WeightNormParamAttr from .param_attr import ParamAttr, WeightNormParamAttr
from . import core from . import core
from six.moves import zip from six.moves import zip
...@@ -50,7 +50,7 @@ class LayerHelper(object): ...@@ -50,7 +50,7 @@ class LayerHelper(object):
return default_startup_program() return default_startup_program()
def to_variable(self, x): def to_variable(self, x):
return base.to_variable(x, self.main_program.current_block()) return imperative_base.to_variable(x, self.main_program.current_block())
def append_op(self, *args, **kwargs): def append_op(self, *args, **kwargs):
return self.main_program.current_block().append_op(*args, **kwargs) return self.main_program.current_block().append_op(*args, **kwargs)
...@@ -314,11 +314,9 @@ class LayerHelper(object): ...@@ -314,11 +314,9 @@ class LayerHelper(object):
WeightNormParamAttr.params_with_weight_norm.append(param) WeightNormParamAttr.params_with_weight_norm.append(param)
return param return param
if _in_imperative_mode(): if _in_imperative_mode():
self.main_program.global_block().create_parameter(
dtype=dtype, shape=shape, **attr._to_kwargs())
# In imperative mode, we want the returned parameter to be # In imperative mode, we want the returned parameter to be
# initialized so that it can be used imperatively. # initialized so that it can be used imperatively.
return self.startup_program.global_block().create_parameter( return self.main_program.global_block().create_parameter(
dtype=dtype, dtype=dtype,
shape=shape, shape=shape,
**attr._to_kwargs(with_initializer=True)) **attr._to_kwargs(with_initializer=True))
...@@ -380,13 +378,16 @@ class LayerHelper(object): ...@@ -380,13 +378,16 @@ class LayerHelper(object):
def set_variable_initializer(self, var, initializer): def set_variable_initializer(self, var, initializer):
assert isinstance(var, Variable) assert isinstance(var, Variable)
self.startup_program.global_block().create_var( if imperative_base.enabled():
name=var.name, initializer(var, var.block)
type=var.type, else:
dtype=var.dtype, self.startup_program.global_block().create_var(
shape=var.shape, name=var.name,
persistable=True, type=var.type,
initializer=initializer) dtype=var.dtype,
shape=var.shape,
persistable=True,
initializer=initializer)
def append_bias_op(self, input_var, dim_start=1, dim_end=None): def append_bias_op(self, input_var, dim_start=1, dim_end=None):
""" """
......
...@@ -502,22 +502,22 @@ def lstm(input, ...@@ -502,22 +502,22 @@ def lstm(input,
If Device is GPU, This op will use cudnn LSTM implementation If Device is GPU, This op will use cudnn LSTM implementation
A four-gate Long Short-Term Memory network with no peephole connections. A four-gate Long Short-Term Memory network with no peephole connections.
In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1, In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1,
the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations: the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations:
.. math:: .. math::
i_t &= \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) i_t &= \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i)
f_t &= \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f) f_t &= \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f)
o_t &= \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o) o_t &= \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o)
\\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c) \\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c)
c_t &= f_t \odot c_{t-1} + i_t \odot \\tilde{c_t} c_t &= f_t \odot c_{t-1} + i_t \odot \\tilde{c_t}
h_t &= o_t \odot tanh(c_t) h_t &= o_t \odot tanh(c_t)
- $W$ terms denote weight matrices (e.g. $W_{ix}$ is the matrix - $W$ terms denote weight matrices (e.g. $W_{ix}$ is the matrix
of weights from the input gate to the input) of weights from the input gate to the input)
...@@ -531,19 +531,19 @@ def lstm(input, ...@@ -531,19 +531,19 @@ def lstm(input,
- :math:`\\tilde{c_t}` is also called candidate hidden state, - :math:`\\tilde{c_t}` is also called candidate hidden state,
which is computed based on the current input and the previous hidden state. which is computed based on the current input and the previous hidden state.
Where sigmoid is the sigmoid operator: :math:`sigmoid(x) = 1 / (1 + e^{-x})` , * represents a point-wise multiplication, Where sigmoid is the sigmoid operator: :math:`sigmoid(x) = 1 / (1 + e^{-x})` , * represents a point-wise multiplication,
X represensts a matrix multiplication X represensts a matrix multiplication
Args: Args:
input (Variable): LSTM input tensor, shape MUST be ( seq_len x batch_size x input_size ) input (Variable): LSTM input tensor, shape MUST be ( seq_len x batch_size x input_size )
init_h(Variable): The initial hidden state of the LSTM init_h(Variable): The initial hidden state of the LSTM
This is a tensor with shape ( num_layers x batch_size x hidden_size) This is a tensor with shape ( num_layers x batch_size x hidden_size)
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size) if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
init_c(Variable): The initial cell state of the LSTM. init_c(Variable): The initial cell state of the LSTM.
This is a tensor with shape ( num_layers x batch_size x hidden_size ) This is a tensor with shape ( num_layers x batch_size x hidden_size )
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size) if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len max_len (int): max length of LSTM. the first dim of input tensor CAN NOT greater than max_len
hidden_size (int): hidden size of the LSTM hidden_size (int): hidden size of the LSTM
num_layers (int): total layers number of the LSTM num_layers (int): total layers number of the LSTM
dropout_prob(float|0.0): dropout prob, dropout ONLY work between rnn layers, NOT between time steps dropout_prob(float|0.0): dropout prob, dropout ONLY work between rnn layers, NOT between time steps
...@@ -558,18 +558,18 @@ def lstm(input, ...@@ -558,18 +558,18 @@ def lstm(input,
Returns: Returns:
rnn_out(Tensor),last_h(Tensor),last_c(Tensor): rnn_out(Tensor),last_h(Tensor),last_c(Tensor):
Three tensors, rnn_out, last_h, last_c: Three tensors, rnn_out, last_h, last_c:
- rnn_out is result of LSTM hidden, shape is (seq_len x batch_size x hidden_size) \ - rnn_out is result of LSTM hidden, shape is (seq_len x batch_size x hidden_size) \
if is_bidirec set to True, shape will be ( seq_len x batch_sze x hidden_size*2) if is_bidirec set to True, shape will be ( seq_len x batch_sze x hidden_size*2)
- last_h is the hidden state of the last step of LSTM \ - last_h is the hidden state of the last step of LSTM \
shape is ( num_layers x batch_size x hidden_size ) \ shape is ( num_layers x batch_size x hidden_size ) \
if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size) if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size)
- last_c(Tensor): the cell state of the last step of LSTM \ - last_c(Tensor): the cell state of the last step of LSTM \
shape is ( num_layers x batch_size x hidden_size ) \ shape is ( num_layers x batch_size x hidden_size ) \
if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size) if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size)
Examples: Examples:
...@@ -1255,7 +1255,7 @@ def dropout(x, ...@@ -1255,7 +1255,7 @@ def dropout(x,
(mask is a tensor same shape with input, value is 0 or 1 (mask is a tensor same shape with input, value is 0 or 1
ratio of 0 is dropout_prob) ratio of 0 is dropout_prob)
Returns: Returns:
Variable: A tensor variable is the shape with `x`. Variable: A tensor variable is the shape with `x`.
...@@ -1346,10 +1346,10 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex): ...@@ -1346,10 +1346,10 @@ def cross_entropy(input, label, soft_label=False, ignore_index=kIgnoreIndex):
ValueError: ValueError:
1. the 1st dimension of ``input`` and ``label`` are not equal. 1. the 1st dimension of ``input`` and ``label`` are not equal.
2. when ``soft_label == True``, and the 2nd dimension of 2. when ``soft_label == True``, and the 2nd dimension of
``input`` and ``label`` are not equal. ``input`` and ``label`` are not equal.
3. when ``soft_label == False``, and the 2nd dimension of 3. when ``soft_label == False``, and the 2nd dimension of
``label`` is not 1. ``label`` is not 1.
...@@ -1471,7 +1471,7 @@ def chunk_eval(input, ...@@ -1471,7 +1471,7 @@ def chunk_eval(input,
This function computes and outputs the precision, recall and This function computes and outputs the precision, recall and
F1-score of chunk detection. F1-score of chunk detection.
For some basics of chunking, please refer to For some basics of chunking, please refer to
`Chunking with Support Vector Machines <https://aclanthology.info/pdf/N/N01/N01-1025.pdf>`_ . `Chunking with Support Vector Machines <https://aclanthology.info/pdf/N/N01/N01-1025.pdf>`_ .
ChunkEvalOp computes the precision, recall, and F1-score of chunk detection, ChunkEvalOp computes the precision, recall, and F1-score of chunk detection,
...@@ -2306,7 +2306,7 @@ def sequence_slice(input, offset, length, name=None): ...@@ -2306,7 +2306,7 @@ def sequence_slice(input, offset, length, name=None):
out.lod = [[2, 1]], out.lod = [[2, 1]],
out.dims = (3, 2). out.dims = (3, 2).
Note: Note:
The first dimension size of **input**, **offset** and **length** The first dimension size of **input**, **offset** and **length**
should be equal. The **offset** should start from 0. should be equal. The **offset** should start from 0.
...@@ -2555,12 +2555,12 @@ def adaptive_pool2d(input, ...@@ -2555,12 +2555,12 @@ def adaptive_pool2d(input,
Examples: Examples:
.. code-block:: python .. code-block:: python
# suppose input data in shape of [N, C, H, W], `pool_size` is [m, n], # suppose input data in shape of [N, C, H, W], `pool_size` is [m, n],
# output shape is [N, C, m, n], adaptive pool divide H and W dimentions # output shape is [N, C, m, n], adaptive pool divide H and W dimentions
# of input data into m * n grids averagely and performs poolings in each # of input data into m * n grids averagely and performs poolings in each
# grid to get output. # grid to get output.
# adaptive average pool performs calculations as follow: # adaptive average pool performs calculations as follow:
# #
# for i in range(m): # for i in range(m):
# for j in range(n): # for j in range(n):
# hstart = floor(i * H / m) # hstart = floor(i * H / m)
...@@ -2649,10 +2649,10 @@ def adaptive_pool3d(input, ...@@ -2649,10 +2649,10 @@ def adaptive_pool3d(input,
# suppose input data in shape of [N, C, D, H, W], `pool_size` is [l, m, n], # suppose input data in shape of [N, C, D, H, W], `pool_size` is [l, m, n],
# output shape is [N, C, l, m, n], adaptive pool divide D, H and W dimentions # output shape is [N, C, l, m, n], adaptive pool divide D, H and W dimentions
# of input data into l * m * n grids averagely and performs poolings in each # of input data into l * m * n grids averagely and performs poolings in each
# grid to get output. # grid to get output.
# adaptive average pool performs calculations as follow: # adaptive average pool performs calculations as follow:
# #
# for i in range(l): # for i in range(l):
# for j in range(m): # for j in range(m):
# for k in range(n): # for k in range(n):
...@@ -2662,7 +2662,7 @@ def adaptive_pool3d(input, ...@@ -2662,7 +2662,7 @@ def adaptive_pool3d(input,
# hend = ceil((j + 1) * H / m) # hend = ceil((j + 1) * H / m)
# wstart = floor(k * W / n) # wstart = floor(k * W / n)
# wend = ceil((k + 1) * W / n) # wend = ceil((k + 1) * W / n)
# output[:, :, i, j, k] = # output[:, :, i, j, k] =
# avg(input[:, :, dstart:dend, hstart: hend, wstart: wend]) # avg(input[:, :, dstart:dend, hstart: hend, wstart: wend])
# #
data = fluid.layers.data( data = fluid.layers.data(
...@@ -4678,7 +4678,7 @@ def ctc_greedy_decoder(input, blank, name=None): ...@@ -4678,7 +4678,7 @@ def ctc_greedy_decoder(input, blank, name=None):
[0.5, 0.1, 0.3, 0.1]] [0.5, 0.1, 0.3, 0.1]]
input.lod = [[4, 4]] input.lod = [[4, 4]]
Computation: Computation:
step1: Apply argmax to first input sequence which is input.data[0:4]. Then we get: step1: Apply argmax to first input sequence which is input.data[0:4]. Then we get:
...@@ -4712,7 +4712,7 @@ def ctc_greedy_decoder(input, blank, name=None): ...@@ -4712,7 +4712,7 @@ def ctc_greedy_decoder(input, blank, name=None):
Variable: CTC greedy decode result which is a 2-D tensor with shape [Lp, 1]. \ Variable: CTC greedy decode result which is a 2-D tensor with shape [Lp, 1]. \
'Lp' is the sum if all output sequences' length. If all the sequences \ 'Lp' is the sum if all output sequences' length. If all the sequences \
in result were empty, the result LoDTensor will be [-1] with \ in result were empty, the result LoDTensor will be [-1] with \
LoD [[]] and dims [1, 1]. LoD [[]] and dims [1, 1].
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -5065,7 +5065,7 @@ def hsigmoid(input, ...@@ -5065,7 +5065,7 @@ def hsigmoid(input,
""" """
The hierarchical sigmoid operator is used to accelerate the training The hierarchical sigmoid operator is used to accelerate the training
process of language model. This operator organizes the classes into a process of language model. This operator organizes the classes into a
complete binary tree, or you can use is_custom to pass your own tree to complete binary tree, or you can use is_custom to pass your own tree to
implement hierarchical. Each leaf node represents a class(a word) and each implement hierarchical. Each leaf node represents a class(a word) and each
internal node acts as a binary classifier. For each word there's a unique internal node acts as a binary classifier. For each word there's a unique
path from root to it's leaf node, hsigmoid calculate the cost for each path from root to it's leaf node, hsigmoid calculate the cost for each
...@@ -5082,7 +5082,7 @@ def hsigmoid(input, ...@@ -5082,7 +5082,7 @@ def hsigmoid(input,
2. build a dict to store word_id -> word's leaf to root path, we call it path_table. 2. build a dict to store word_id -> word's leaf to root path, we call it path_table.
3. build a dict to store word_id -> code of word's leaf to root path, we call it path_code. Code 3. build a dict to store word_id -> code of word's leaf to root path, we call it path_code. Code
means label of each binary classification, using 1 indicate true, 0 indicate false. means label of each binary classification, using 1 indicate true, 0 indicate false.
4. now, each word should has its path and code along the path, you can pass a batch of path and code 4. now, each word should has its path and code along the path, you can pass a batch of path and code
related to the same batch of inputs. related to the same batch of inputs.
Args: Args:
...@@ -5091,8 +5091,8 @@ def hsigmoid(input, ...@@ -5091,8 +5091,8 @@ def hsigmoid(input,
and :math:`D` is the feature size. and :math:`D` is the feature size.
label (Variable): The tensor variable contains labels of training data. label (Variable): The tensor variable contains labels of training data.
It's a tensor with shape is :math:`[N \\times 1]`. It's a tensor with shape is :math:`[N \\times 1]`.
num_classes: (int), The number of classes, must not be less than 2. with default tree this has to be set, num_classes: (int), The number of classes, must not be less than 2. with default tree this has to be set,
it should never be None under is_custom=False, but while is_custom is true, it should be non leaf num it should never be None under is_custom=False, but while is_custom is true, it should be non leaf num
which indicates the num of classes using by binary classify. which indicates the num of classes using by binary classify.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid
...@@ -5105,15 +5105,15 @@ def hsigmoid(input, ...@@ -5105,15 +5105,15 @@ def hsigmoid(input,
is not set, the bias is initialized zero. Default: None. is not set, the bias is initialized zero. Default: None.
name (str|None): A name for this layer(optional). If set None, the layer name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None. will be named automatically. Default: None.
path_table: (Variable|None) this variable can store each batch of samples' path to root, path_table: (Variable|None) this variable can store each batch of samples' path to root,
it should be in leaf -> root order it should be in leaf -> root order
path_table should have the same shape with path_code, and for each sample i path_table[i] indicates a np.array like path_table should have the same shape with path_code, and for each sample i path_table[i] indicates a np.array like
structure and each element in this array is indexes in parent nodes' Weight Matrix. structure and each element in this array is indexes in parent nodes' Weight Matrix.
path_code: (Variable|None) this variable can store each batch of samples' code, path_code: (Variable|None) this variable can store each batch of samples' code,
each code consist with every code of parent nodes. it should be in leaf -> root order each code consist with every code of parent nodes. it should be in leaf -> root order
is_custom: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is is_custom: (bool|False)using user defined binary tree instead of default complete binary tree, if costum is
set you need to set path_table/path_code/num_classes, otherwise num_classes should be set set you need to set path_table/path_code/num_classes, otherwise num_classes should be set
is_sparse: (bool|False)using sparse update instead of dense update, if set, the gradient is_sparse: (bool|False)using sparse update instead of dense update, if set, the gradient
of W and input will be sparse. of W and input will be sparse.
Returns: Returns:
...@@ -6965,10 +6965,10 @@ def mean_iou(input, label, num_classes): ...@@ -6965,10 +6965,10 @@ def mean_iou(input, label, num_classes):
num_classes (int): The possible number of labels. num_classes (int): The possible number of labels.
Returns: Returns:
mean_iou (Variable),out_wrong(Variable),out_correct(Variable): mean_iou (Variable),out_wrong(Variable),out_correct(Variable):
Three variables: Three variables:
- mean_iou : A Tensor representing the mean intersection-over-union with shape [1]. - mean_iou : A Tensor representing the mean intersection-over-union with shape [1].
- out_wrong: A Tensor with shape [num_classes]. The wrong numbers of each class. - out_wrong: A Tensor with shape [num_classes]. The wrong numbers of each class.
- out_correct: A Tensor with shape [num_classes]. The correct numbers of each class. - out_correct: A Tensor with shape [num_classes]. The correct numbers of each class.
...@@ -7166,7 +7166,7 @@ def affine_grid(theta, out_shape, name=None): ...@@ -7166,7 +7166,7 @@ def affine_grid(theta, out_shape, name=None):
Args: Args:
theta (Variable): A batch of affine transform parameters with shape [N, 2, 3]. theta (Variable): A batch of affine transform parameters with shape [N, 2, 3].
out_shape (Variable | list | tuple): The shape of target output with format [N, C, H, W]. out_shape (Variable | list | tuple): The shape of target output with format [N, C, H, W].
``out_shape`` can be a Variable or a list or tuple. ``out_shape`` can be a Variable or a list or tuple.
name(str|None): A name for this layer(optional). If set None, the layer name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
...@@ -7762,9 +7762,9 @@ def flatten(x, axis=1, name=None): ...@@ -7762,9 +7762,9 @@ def flatten(x, axis=1, name=None):
""" """
**Flatten layer** **Flatten layer**
Flattens the input tensor into a 2D matrix. Flattens the input tensor into a 2D matrix.
For Example: For Example:
.. code-block:: text .. code-block:: text
Case 1: Case 1:
...@@ -8942,7 +8942,7 @@ def similarity_focus(input, axis, indexes, name=None): ...@@ -8942,7 +8942,7 @@ def similarity_focus(input, axis, indexes, name=None):
SimilarityFocus Operator SimilarityFocus Operator
Generate a similarity focus mask with the same shape of input using the following method: Generate a similarity focus mask with the same shape of input using the following method:
1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding 1. Extract the 3-D tensor(here the first dimension is BatchSize) corresponding
to the axis according to the indexes. For example, if axis=1 and indexes=[a], to the axis according to the indexes. For example, if axis=1 and indexes=[a],
it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X it will get the matrix T=X[:, a, :, :]. In this case, if the shape of input X
...@@ -9403,7 +9403,7 @@ class PyFuncRegistry(object): ...@@ -9403,7 +9403,7 @@ class PyFuncRegistry(object):
raise TypeError('func must be a Python function') raise TypeError('func must be a Python function')
self._func = func self._func = func
# find named args using reflection # find named args using reflection
args = inspect.getargspec(self._func) args = inspect.getargspec(self._func)
if len(args[0]) == 0 and args[1] is None and args[2] is None: if len(args[0]) == 0 and args[1] is None and args[2] is None:
# Function with no inputs # Function with no inputs
...@@ -9414,15 +9414,15 @@ class PyFuncRegistry(object): ...@@ -9414,15 +9414,15 @@ class PyFuncRegistry(object):
''' '''
Why record self here? Why record self here?
1. For debug usage. Users can call 1. For debug usage. Users can call
:code:`py_func.registered_func(idx)` method :code:`py_func.registered_func(idx)` method
to find the registered function corresponding to find the registered function corresponding
to :code:`idx`. to :code:`idx`.
2. For increasing reference count of self. 2. For increasing reference count of self.
It seems that to release Python object It seems that to release Python object
whose reference count is 1 would cause whose reference count is 1 would cause
segmentation fault error in C++ side. segmentation fault error in C++ side.
May be lack of Python GC in C++ side? May be lack of Python GC in C++ side?
''' '''
PyFuncRegistry._register_funcs.append(self) PyFuncRegistry._register_funcs.append(self)
...@@ -9473,7 +9473,7 @@ class PyFuncRegistry(object): ...@@ -9473,7 +9473,7 @@ class PyFuncRegistry(object):
def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
""" """
PyFunc Operator. PyFunc Operator.
User can use :code:`py_func` to register operators in Python side. User can use :code:`py_func` to register operators in Python side.
The inputs of :code:`func` is :code:`LoDTensor` and outputs can be The inputs of :code:`func` is :code:`LoDTensor` and outputs can be
numpy array or :code:`LoDTensor`. Paddle would call the registered numpy array or :code:`LoDTensor`. Paddle would call the registered
...@@ -9491,7 +9491,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): ...@@ -9491,7 +9491,7 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
no gradient, users should return None. no gradient, users should return None.
This function can also be used to debug the running network. User can This function can also be used to debug the running network. User can
add a :code:`py_func` operator without output, and print input add a :code:`py_func` operator without output, and print input
:code:`x` inside :code:`func`. :code:`x` inside :code:`func`.
Args: Args:
...@@ -9499,50 +9499,50 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None): ...@@ -9499,50 +9499,50 @@ def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
x (Variable|list(Variable)|tuple(Variable)): inputs of :code:`func`. x (Variable|list(Variable)|tuple(Variable)): inputs of :code:`func`.
out (Variable|list(Variable)|tuple(Variable)): outputs of :code:`func`. out (Variable|list(Variable)|tuple(Variable)): outputs of :code:`func`.
Paddle cannot infer shapes and data types of :code:`out`. Users Paddle cannot infer shapes and data types of :code:`out`. Users
should create :code:`out` beforehand. should create :code:`out` beforehand.
backward_func (callable|None): backward Python function. backward_func (callable|None): backward Python function.
None means no backward. Default None. None means no backward. Default None.
skip_vars_in_backward_input (Variable|list(Variable)|tuple(Variable)): skip_vars_in_backward_input (Variable|list(Variable)|tuple(Variable)):
Variables that are not needed in :code:`backward_func` inputs. Variables that are not needed in :code:`backward_func` inputs.
These variables must be any of :code:`x` and :code:`out`. These variables must be any of :code:`x` and :code:`out`.
If set, these vars would not be inputs of :code:`backward_func`, If set, these vars would not be inputs of :code:`backward_func`,
Only useful when :code:`backward_func` is not None. Default None. Only useful when :code:`backward_func` is not None. Default None.
Returns: Returns:
out (Variable|list(Variable)|tuple(Variable)): input :code:`out` out (Variable|list(Variable)|tuple(Variable)): input :code:`out`
Examples: Examples:
>>> import paddle.fluid as fluid >>> import paddle.fluid as fluid
>>> import six >>> import six
>>> >>>
>>> def create_tmp_var(name, dtype, shape): >>> def create_tmp_var(name, dtype, shape):
>>> return fluid.default_main_program().current_block().create_var( >>> return fluid.default_main_program().current_block().create_var(
>>> name=name, dtype=dtype, shape=shape) >>> name=name, dtype=dtype, shape=shape)
>>> >>>
>>> # tanh activation has been provided by Paddle C++ op >>> # tanh activation has been provided by Paddle C++ op
>>> # Here, we only use tanh to be an example to show the usage >>> # Here, we only use tanh to be an example to show the usage
>>> # of py_func >>> # of py_func
>>> def tanh(x): >>> def tanh(x):
>>> return np.tanh(x) >>> return np.tanh(x)
>>> >>>
>>> # forward input x is skipped >>> # forward input x is skipped
>>> def tanh_grad(y, dy): >>> def tanh_grad(y, dy):
>>> return np.array(dy) * (1 - np.square(np.array(y))) >>> return np.array(dy) * (1 - np.square(np.array(y)))
>>> >>>
>>> def debug_func(x): >>> def debug_func(x):
>>> print(x) >>> print(x)
>>> >>>
>>> def simple_net(img, label): >>> def simple_net(img, label):
>>> hidden = img >>> hidden = img
>>> for idx in six.moves.range(4): >>> for idx in six.moves.range(4):
>>> hidden = fluid.layers.fc(hidden, size=200) >>> hidden = fluid.layers.fc(hidden, size=200)
>>> new_hidden = create_tmp_var(name='hidden_{}'.format(idx), >>> new_hidden = create_tmp_var(name='hidden_{}'.format(idx),
>>> dtype=hidden.dtype, shape=hidden.shape) >>> dtype=hidden.dtype, shape=hidden.shape)
>>> >>>
>>> # user-defined layers with forward and backward >>> # user-defined layers with forward and backward
>>> hidden = fluid.layers.py_func(func=tanh, x=hidden, >>> hidden = fluid.layers.py_func(func=tanh, x=hidden,
>>> out=new_hidden, backward_func=tanh_grad, >>> out=new_hidden, backward_func=tanh_grad,
>>> skip_vars_in_backward_input=hidden) >>> skip_vars_in_backward_input=hidden)
>>> >>>
>>> # user-defined debug layers to print variables >>> # user-defined debug layers to print variables
...@@ -9713,47 +9713,3 @@ def huber_loss(input, label, delta): ...@@ -9713,47 +9713,3 @@ def huber_loss(input, label, delta):
'Residual': residual}, 'Residual': residual},
attrs={'delta': delta}) attrs={'delta': delta})
return out return out
class FC(layers.PyLayer):
def __init__(self,
size,
param_attr=None,
num_flatten_dims=1,
dtype=core.VarDesc.VarType.FP32):
super(FC, self).__init__()
self._size = size
self._num_flatten_dims = num_flatten_dims
self._dtype = dtype
self._helper = LayerHelper('FC', param_attr=param_attr)
def _build_once(self, inputs):
input_shape = inputs[0].shape
param_shape = [
reduce(lambda a, b: a * b, input_shape[self._num_flatten_dims:], 1)
] + [self._size]
self._w = self._helper.create_parameter(
attr=self._helper.param_attr,
shape=param_shape,
dtype=self._dtype,
is_bias=False)
def forward(self, inputs):
tmp = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="mul",
inputs={"X": inputs[0],
"Y": self._w},
outputs={"Out": tmp},
attrs={
"x_num_col_dims": self._num_flatten_dims,
"y_num_col_dims": 1
})
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="sum",
inputs={"X": [tmp]},
outputs={"Out": out},
attrs={"use_mkldnn": False})
return out
...@@ -20,6 +20,7 @@ from ..framework import convert_np_dtype_to_dtype_ ...@@ -20,6 +20,7 @@ from ..framework import convert_np_dtype_to_dtype_
from ..framework import Variable from ..framework import Variable
from ..initializer import Constant, force_init_on_cpu from ..initializer import Constant, force_init_on_cpu
from ..core import VarDesc from ..core import VarDesc
from ..imperative import base as imperative_base
from .layer_function_generator import templatedoc from .layer_function_generator import templatedoc
import numpy import numpy
...@@ -104,15 +105,15 @@ def create_global_var(shape, ...@@ -104,15 +105,15 @@ def create_global_var(shape,
Args: Args:
shape(list[int]): shape of the variable shape(list[int]): shape of the variable
value(float): the value of the variable. The new created value(float): the value of the variable. The new created
variable will be filled with it. variable will be filled with it.
dtype(string): data type of the variable dtype(string): data type of the variable
persistable(bool): if this variable is persistable. persistable(bool): if this variable is persistable.
Default: False Default: False
force_cpu(bool): force this variable to be on CPU. force_cpu(bool): force this variable to be on CPU.
Default: False Default: False
name(str|None): The name of the variable. If set to None the variable name(str|None): The name of the variable. If set to None the variable
name will be generated automatically. name will be generated automatically.
Default: None Default: None
Returns: Returns:
...@@ -121,21 +122,26 @@ def create_global_var(shape, ...@@ -121,21 +122,26 @@ def create_global_var(shape,
Examples: Examples:
.. code-block:: python .. code-block:: python
var = fluid.create_global_var(shape=[2,3], value=1.0, dtype='float32', var = fluid.create_global_var(shape=[2,3], value=1.0, dtype='float32',
persistable=True, force_cpu=True, name='new_var') persistable=True, force_cpu=True, name='new_var')
""" """
helper = LayerHelper("global_var", **locals()) helper = LayerHelper("global_var", **locals())
var = helper.create_global_variable( var = helper.create_global_variable(
dtype=dtype, shape=shape, persistable=persistable, name=name) dtype=dtype,
shape=shape,
persistable=persistable,
name=name,
stop_gradient=True)
helper.set_variable_initializer( helper.set_variable_initializer(
var, initializer=Constant( var, initializer=Constant(
value=float(value), force_cpu=force_cpu)) value=float(value), force_cpu=force_cpu))
return var return var
def cast(x, dtype): def cast(x, dtype):
""" """
This layer takes in the Variable :attr:`x` with :attr:`x.dtype` and casts This layer takes in the Variable :attr:`x` with :attr:`x.dtype` and casts
it to the output with :attr:`dtype`. it to the output with :attr:`dtype`.
Args: Args:
...@@ -199,9 +205,9 @@ def tensor_array_to_tensor(input, axis=1, name=None): ...@@ -199,9 +205,9 @@ def tensor_array_to_tensor(input, axis=1, name=None):
and returns that as the output. and returns that as the output.
A simple example as below: A simple example as below:
.. code-block:: text .. code-block:: text
Given: Given:
input.data = {[[0.6, 0.1, 0.3], input.data = {[[0.6, 0.1, 0.3],
...@@ -210,9 +216,9 @@ def tensor_array_to_tensor(input, axis=1, name=None): ...@@ -210,9 +216,9 @@ def tensor_array_to_tensor(input, axis=1, name=None):
[1.8]], [1.8]],
[[2.3, 2.1], [[2.3, 2.1],
[2.5, 2.4]]} [2.5, 2.4]]}
axis = 1 axis = 1
Then: Then:
output.data = [[0.6, 0.1, 0.3, 1.3, 2.3, 2.1], output.data = [[0.6, 0.1, 0.3, 1.3, 2.3, 2.1],
...@@ -498,12 +504,12 @@ def argmax(x, axis=0): ...@@ -498,12 +504,12 @@ def argmax(x, axis=0):
def argsort(input, axis=-1, name=None): def argsort(input, axis=-1, name=None):
""" """
Performs sorting on the input Variable along the given axis, and outputs Performs sorting on the input Variable along the given axis, and outputs
sorted data Varibale and its corresponding index Variable with the same sorted data Varibale and its corresponding index Variable with the same
shape as :attr:`input`. shape as :attr:`input`.
.. code-block:: text .. code-block:: text
For example, the given axis is -1 and the input Variable For example, the given axis is -1 and the input Variable
input = [[0.15849551, 0.45865775, 0.8563702 ], input = [[0.15849551, 0.45865775, 0.8563702 ],
...@@ -516,15 +522,15 @@ def argsort(input, axis=-1, name=None): ...@@ -516,15 +522,15 @@ def argsort(input, axis=-1, name=None):
and the sorted indices along the given axis turn outs to be and the sorted indices along the given axis turn outs to be
indices = [[0, 1, 2], indices = [[0, 1, 2],
[0, 2, 1]] [0, 2, 1]]
Args: Args:
input(Variable): The input Variable for sorting. input(Variable): The input Variable for sorting.
axis(int): The axis along which to sort the input Variable. When axis(int): The axis along which to sort the input Variable. When
:attr:`axis` < 0, the actual axis will be :attr:`axis` + :attr:`axis` < 0, the actual axis will be :attr:`axis` +
rank(:attr:`input`). Default -1, the last dimension. rank(:attr:`input`). Default -1, the last dimension.
name(str|None): (optional) A name for this layer. If set None, the name(str|None): (optional) A name for this layer. If set None, the
layer will be named automatically. layer will be named automatically.
Returns: Returns:
......
...@@ -30,6 +30,7 @@ from .initializer import Constant ...@@ -30,6 +30,7 @@ from .initializer import Constant
from .layer_helper import LayerHelper from .layer_helper import LayerHelper
from .layers import ops from .layers import ops
from .regularizer import append_regularization_ops from .regularizer import append_regularization_ops
from .imperative import base as imperative_base
__all__ = [ __all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl', 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
...@@ -301,25 +302,45 @@ class Optimizer(object): ...@@ -301,25 +302,45 @@ class Optimizer(object):
This method combines interface `append_backward()` and This method combines interface `append_backward()` and
`create_optimization_pass()` into one. `create_optimization_pass()` into one.
""" """
params_grads = append_backward(loss, parameter_list, no_grad_set, if imperative_base.enabled():
[error_clip_callback]) if parameter_list is not None:
params_grads = parameter_list
else:
program = loss.block.program
parameters = program.global_block().all_parameters()
params_grads = []
for param in parameters:
# create gradient variable
grad_var = Variable(
block=loss.block,
name=param._ivar._grad_name(),
stop_gradient=True)
grad_var._value = param._ivar.grad_value
params_grads.append((param, grad_var))
optimize_ops = self._create_optimization_pass(params_grads, loss,
startup_program)
else:
params_grads = append_backward(loss, parameter_list, no_grad_set,
[error_clip_callback])
params_grads = sorted(params_grads, key=lambda x: x[0].name)
params_grads = sorted(params_grads, key=lambda x: x[0].name) params_grads, table_param_and_grad, table_optimize_op = \
self._process_distribute_lookuptable(params_grads, loss, startup_program)
params_grads, table_param_and_grad, table_optimize_op = \ params_grads = append_gradient_clip_ops(params_grads)
self._process_distribute_lookuptable(params_grads, loss, startup_program)
params_grads = append_gradient_clip_ops(params_grads) # Add regularization if any
params_grads = append_regularization_ops(params_grads,
self.regularization)
# Add regularization if any optimize_ops = self._create_optimization_pass(params_grads, loss,
params_grads = append_regularization_ops(params_grads, startup_program)
self.regularization) if table_optimize_op is not None:
optimize_ops.append(table_optimize_op)
params_grads.append(table_param_and_grad)
optimize_ops = self._create_optimization_pass(params_grads, loss,
startup_program)
if table_optimize_op is not None:
optimize_ops.append(table_optimize_op)
params_grads.append(table_param_and_grad)
return optimize_ops, params_grads return optimize_ops, params_grads
...@@ -364,7 +385,8 @@ class SGDOptimizer(Optimizer): ...@@ -364,7 +385,8 @@ class SGDOptimizer(Optimizer):
"Grad": param_and_grad[1], "Grad": param_and_grad[1],
"LearningRate": self._create_param_lr(param_and_grad) "LearningRate": self._create_param_lr(param_and_grad)
}, },
outputs={"ParamOut": param_and_grad[0]}) outputs={"ParamOut": param_and_grad[0]},
stop_gradient=True)
return sgd_op return sgd_op
...@@ -448,7 +470,8 @@ class MomentumOptimizer(Optimizer): ...@@ -448,7 +470,8 @@ class MomentumOptimizer(Optimizer):
"VelocityOut": velocity_acc "VelocityOut": velocity_acc
}, },
attrs={"mu": self._momentum, attrs={"mu": self._momentum,
"use_nesterov": self._use_nesterov}) "use_nesterov": self._use_nesterov},
stop_gradient=True)
return momentum_op return momentum_op
...@@ -477,7 +500,7 @@ class LarsMomentumOptimizer(Optimizer): ...@@ -477,7 +500,7 @@ class LarsMomentumOptimizer(Optimizer):
regularization: A Regularizer, such as regularization: A Regularizer, such as
fluid.regularizer.L2DecayRegularizer. fluid.regularizer.L2DecayRegularizer.
name: A optional name prefix. name: A optional name prefix.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -533,7 +556,8 @@ class LarsMomentumOptimizer(Optimizer): ...@@ -533,7 +556,8 @@ class LarsMomentumOptimizer(Optimizer):
"mu": self._momentum, "mu": self._momentum,
"lars_coeff": self._lars_coeff, "lars_coeff": self._lars_coeff,
"lars_weight_decay": self._lars_weight_decay "lars_weight_decay": self._lars_weight_decay
}) },
stop_gradient=True)
return momentum_op return momentum_op
...@@ -608,7 +632,8 @@ class AdagradOptimizer(Optimizer): ...@@ -608,7 +632,8 @@ class AdagradOptimizer(Optimizer):
}, },
outputs={"ParamOut": param_and_grad[0], outputs={"ParamOut": param_and_grad[0],
"MomentOut": moment_acc}, "MomentOut": moment_acc},
attrs={"epsilon": self._epsilon}) attrs={"epsilon": self._epsilon},
stop_gradient=True)
return adagrad_op return adagrad_op
...@@ -738,7 +763,8 @@ class AdamOptimizer(Optimizer): ...@@ -738,7 +763,8 @@ class AdamOptimizer(Optimizer):
"beta2": self._beta2, "beta2": self._beta2,
"epsilon": self._epsilon, "epsilon": self._epsilon,
"lazy_mode": self._lazy_mode "lazy_mode": self._lazy_mode
}) },
stop_gradient=True)
return adam_op return adam_op
...@@ -760,13 +786,15 @@ class AdamOptimizer(Optimizer): ...@@ -760,13 +786,15 @@ class AdamOptimizer(Optimizer):
type="scale", type="scale",
inputs={"X": beta1_pow_acc}, inputs={"X": beta1_pow_acc},
outputs={"Out": beta1_pow_acc}, outputs={"Out": beta1_pow_acc},
attrs={"scale": self._beta1}) attrs={"scale": self._beta1},
stop_gradient=True)
main_block.append_op( main_block.append_op(
type="scale", type="scale",
inputs={"X": beta2_pow_acc}, inputs={"X": beta2_pow_acc},
outputs={"Out": beta2_pow_acc}, outputs={"Out": beta2_pow_acc},
attrs={"scale": self._beta2}) attrs={"scale": self._beta2},
stop_gradient=True)
class AdamaxOptimizer(Optimizer): class AdamaxOptimizer(Optimizer):
...@@ -877,7 +905,8 @@ class AdamaxOptimizer(Optimizer): ...@@ -877,7 +905,8 @@ class AdamaxOptimizer(Optimizer):
"beta1": self._beta1, "beta1": self._beta1,
"beta2": self._beta2, "beta2": self._beta2,
"epsilon": self._epsilon "epsilon": self._epsilon
}) },
stop_gradient=True)
return adamax_op return adamax_op
...@@ -897,7 +926,8 @@ class AdamaxOptimizer(Optimizer): ...@@ -897,7 +926,8 @@ class AdamaxOptimizer(Optimizer):
type="scale", type="scale",
inputs={"X": beta1_pow_acc}, inputs={"X": beta1_pow_acc},
outputs={"Out": beta1_pow_acc}, outputs={"Out": beta1_pow_acc},
attrs={"scale": self._beta1}) attrs={"scale": self._beta1},
stop_gradient=True)
class DecayedAdagradOptimizer(Optimizer): class DecayedAdagradOptimizer(Optimizer):
...@@ -979,7 +1009,8 @@ class DecayedAdagradOptimizer(Optimizer): ...@@ -979,7 +1009,8 @@ class DecayedAdagradOptimizer(Optimizer):
}, },
outputs={"ParamOut": param_and_grad[0], outputs={"ParamOut": param_and_grad[0],
"MomentOut": moment_acc}, "MomentOut": moment_acc},
attrs={"epsilon": self._epsilon}) attrs={"epsilon": self._epsilon},
stop_gradient=True)
return decayed_adagrad_op return decayed_adagrad_op
...@@ -1075,7 +1106,8 @@ class AdadeltaOptimizer(Optimizer): ...@@ -1075,7 +1106,8 @@ class AdadeltaOptimizer(Optimizer):
"AvgSquaredUpdateOut": avg_squared_update_acc "AvgSquaredUpdateOut": avg_squared_update_acc
}, },
attrs={"epsilon": self._epsilon, attrs={"epsilon": self._epsilon,
"rho": self._rho}) "rho": self._rho},
stop_gradient=True)
return adadelta_op return adadelta_op
...@@ -1224,7 +1256,8 @@ class RMSPropOptimizer(Optimizer): ...@@ -1224,7 +1256,8 @@ class RMSPropOptimizer(Optimizer):
"decay": self._rho, "decay": self._rho,
"momentum": self._momentum, "momentum": self._momentum,
"centered": self._centered "centered": self._centered
}) },
stop_gradient=True)
return rmsprop_op return rmsprop_op
...@@ -1345,7 +1378,8 @@ class FtrlOptimizer(Optimizer): ...@@ -1345,7 +1378,8 @@ class FtrlOptimizer(Optimizer):
}, },
attrs={"l1": self._l1, attrs={"l1": self._l1,
"l2": self._l1, "l2": self._l1,
"lr_power": self._lr_power}) "lr_power": self._lr_power},
stop_gradient=True)
return ftrl_op return ftrl_op
...@@ -1509,7 +1543,8 @@ class ModelAverage(Optimizer): ...@@ -1509,7 +1543,8 @@ class ModelAverage(Optimizer):
"average_window": self.average_window, "average_window": self.average_window,
"min_average_window": self.min_average_window, "min_average_window": self.min_average_window,
"max_average_window": self.max_average_window, "max_average_window": self.max_average_window,
}) },
stop_gradient=True)
@contextmanager @contextmanager
def apply(self, executor, need_restore=True): def apply(self, executor, need_restore=True):
......
...@@ -18,17 +18,8 @@ import numpy as np ...@@ -18,17 +18,8 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.layers.nn import FC from paddle.fluid.imperative.nn import FC
from test_imperative_base import new_program_scope
@contextlib.contextmanager
def new_program_scope():
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
class MyLayer(fluid.imperative.PyLayer): class MyLayer(fluid.imperative.PyLayer):
...@@ -36,7 +27,7 @@ class MyLayer(fluid.imperative.PyLayer): ...@@ -36,7 +27,7 @@ class MyLayer(fluid.imperative.PyLayer):
super(MyLayer, self).__init__() super(MyLayer, self).__init__()
def forward(self, inputs): def forward(self, inputs):
x = fluid.layers.relu(inputs[0]) x = fluid.layers.relu(inputs)
self._x_for_debug = x self._x_for_debug = x
x = fluid.layers.elementwise_mul(x, x) x = fluid.layers.elementwise_mul(x, x)
x = fluid.layers.reduce_sum(x) x = fluid.layers.reduce_sum(x)
...@@ -54,7 +45,7 @@ class MLP(fluid.imperative.PyLayer): ...@@ -54,7 +45,7 @@ class MLP(fluid.imperative.PyLayer):
initializer=fluid.initializer.Constant(value=0.1))) initializer=fluid.initializer.Constant(value=0.1)))
def forward(self, inputs): def forward(self, inputs):
x = self._fc1(inputs[0]) x = self._fc1(inputs)
x = self._fc2(x) x = self._fc2(x)
x = fluid.layers.reduce_sum(x) x = fluid.layers.reduce_sum(x)
return x return x
...@@ -66,13 +57,14 @@ class TestImperative(unittest.TestCase): ...@@ -66,13 +57,14 @@ class TestImperative(unittest.TestCase):
cl = core.Layer() cl = core.Layer()
cl.forward([]) cl.forward([])
l = fluid.imperative.PyLayer() l = fluid.imperative.PyLayer()
l.forward([]) self.assertRaises(NotImplementedError, l.forward, [])
def test_layer_in_out(self): def test_layer_in_out(self):
np_inp = np.array([1.0, 2.0, -1.0], dtype=np.float32) np_inp = np.array([1.0, 2.0, -1.0], dtype=np.float32)
with fluid.imperative.guard(): with fluid.imperative.guard():
var_inp = fluid.imperative.base.to_variable(np_inp)
l = MyLayer() l = MyLayer()
x = l(np_inp)[0] x = l(var_inp)[0]
self.assertIsNotNone(x) self.assertIsNotNone(x)
dy_out = x._numpy() dy_out = x._numpy()
x._backward() x._backward()
...@@ -97,8 +89,9 @@ class TestImperative(unittest.TestCase): ...@@ -97,8 +89,9 @@ class TestImperative(unittest.TestCase):
def test_mlp(self): def test_mlp(self):
np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
with fluid.imperative.guard(): with fluid.imperative.guard():
var_inp = fluid.imperative.base.to_variable(np_inp)
mlp = MLP() mlp = MLP()
out = mlp(np_inp) out = mlp(var_inp)
dy_out = out._numpy() dy_out = out._numpy()
out._backward() out._backward()
dy_grad = mlp._fc1._w._gradient() dy_grad = mlp._fc1._w._gradient()
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core
@contextlib.contextmanager
def new_program_scope():
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import unittest
import numpy as np
import six
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.imperative.nn import Conv2D, Pool2D, FC
from paddle.fluid.imperative.base import to_variable
from test_imperative_base import new_program_scope
class SimpleImgConvPool(fluid.imperative.PyLayer):
def __init__(self,
num_channels,
num_filters,
filter_size,
pool_size,
pool_stride,
pool_padding=0,
pool_type='max',
global_pooling=False,
conv_stride=1,
conv_padding=0,
conv_dilation=1,
conv_groups=1,
act=None,
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConvPool, self).__init__()
self._conv2d = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,
padding=conv_padding,
dilation=conv_dilation,
groups=conv_groups,
param_attr=None,
bias_attr=None,
use_cudnn=use_cudnn)
self._pool2d = Pool2D(
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
pool_padding=pool_padding,
global_pooling=global_pooling,
use_cudnn=use_cudnn)
def forward(self, inputs):
x = self._conv2d(inputs)
x = self._pool2d(x)
return x
class MNIST(fluid.imperative.PyLayer):
def __init__(self, param_attr=None, bias_attr=None):
super(MNIST, self).__init__()
self._simple_img_conv_pool_1 = SimpleImgConvPool(
1, 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool(
20, 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 8 * 8
SIZE = 10
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
self._fc = FC(10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)))
def forward(self, inputs):
x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x)
x = self._fc(x)
return x
class TestImperativeMnist(unittest.TestCase):
def test_mnist_cpu_float32(self):
seed = 90
with fluid.imperative.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
# mnist = Conv2D(1, 20, 5)
mnist = MNIST()
sgd = SGDOptimizer(learning_rate=1e-3)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128)
dy_param_init_value = {}
for batch_id, data in enumerate(train_reader()):
if batch_id >= 2:
break
x_data = np.array(
[x[0].reshape(1, 28, 28) for x in data]).astype('float32')
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
128, 1)
img = to_variable(x_data)
label = to_variable(y_data)
label._stop_gradient = True
cost = mnist(img)
loss = fluid.layers.reduce_mean(cost)
dy_out = loss._numpy()
if batch_id == 0:
for param in fluid.default_main_program().global_block(
).all_parameters():
dy_param_init_value[param.name] = param._numpy()
loss._backward()
sgd.minimize(loss)
dy_param_value = {}
for param in fluid.default_main_program().global_block(
).all_parameters():
dy_param_value[param.name] = param._numpy()
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
exe = fluid.Executor(fluid.CPUPlace())
# mnist = Conv2D(1, 20, 5)
mnist = MNIST()
sgd = SGDOptimizer(learning_rate=1e-3)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=128)
img = fluid.layers.data(
name='pixel', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
cost = mnist(img)
loss = fluid.layers.reduce_mean(cost)
sgd.minimize(loss)
# initialize params and fetch them
static_param_init_value = {}
static_param_name_list = []
for param in fluid.default_startup_program().global_block(
).all_parameters():
static_param_name_list.append(param.name)
out = exe.run(fluid.default_startup_program(),
fetch_list=static_param_name_list)
for i in range(len(static_param_name_list)):
static_param_init_value[static_param_name_list[i]] = out[i]
for batch_id, data in enumerate(train_reader()):
if batch_id >= 2:
break
x_data = np.array(
[x[0].reshape(1, 28, 28) for x in data]).astype('float32')
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
[128, 1])
fetch_list = [loss.name]
fetch_list.extend(static_param_name_list)
out = exe.run(fluid.default_main_program(),
feed={"pixel": x_data,
"label": y_data},
fetch_list=fetch_list)
static_param_value = {}
static_out = out[0]
for i in range(1, len(out)):
static_param_value[static_param_name_list[i - 1]] = out[i]
for key, value in six.iteritems(static_param_init_value):
self.assertTrue(
np.allclose(value.all(), dy_param_init_value[key].all()))
self.assertTrue(np.allclose(static_out.all(), dy_out.all()))
for key, value in six.iteritems(static_param_value):
self.assertTrue(np.allclose(value.all(), dy_param_value[key].all()))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册