Commit 2547f9d1 authored by minqiyang

Polish code

test=develop
Parent 09e2e662
......@@ -69,7 +69,7 @@ inline std::string GradVarName(const std::string& var_name) {
return result;
}
-inline std::string OriginVarName(const std::string& grad_var_name) {
+inline std::string GradOriginalVarName(const std::string& grad_var_name) {
std::size_t pos = grad_var_name.rfind(kGradVarSuffix);
if (pos == std::string::npos) {
return grad_var_name;
......
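Note: the hunk above is truncated. As a rough sketch (not the verbatim Paddle source), the renamed helper presumably just strips the trailing kGradVarSuffix, so "X@GRAD" maps back to "X" and suffix-free names pass through unchanged — exactly the round trip the VarNameTest hunk below exercises:

```cpp
#include <string>

// Assumed suffix value, taken from the "@GRAD" names in the test below.
constexpr char kGradVarSuffix[] = "@GRAD";

// Sketch of the renamed helper: "X@GRAD" -> "X"; names without the
// suffix are returned unchanged.
inline std::string GradOriginalVarName(const std::string& grad_var_name) {
  std::size_t pos = grad_var_name.rfind(kGradVarSuffix);
  if (pos == std::string::npos) {
    return grad_var_name;
  }
  return grad_var_name.substr(0, pos);
}
```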
......@@ -294,24 +294,24 @@ TEST(VarNameTest, all) {
std::string grad_var_name = paddle::framework::GradVarName(var_name);
ASSERT_EQ(grad_var_name, "X@GRAD");
std::string original_var_name =
-paddle::framework::OriginVarName(grad_var_name);
+paddle::framework::GradOriginalVarName(grad_var_name);
ASSERT_EQ(original_var_name, "X");
-original_var_name = paddle::framework::OriginVarName(original_var_name);
+original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
ASSERT_EQ(original_var_name, "X");
std::string var_name_2("XYZ");
grad_var_name = paddle::framework::GradVarName(var_name_2);
ASSERT_EQ(grad_var_name, "XYZ@GRAD");
-original_var_name = paddle::framework::OriginVarName(grad_var_name);
+original_var_name = paddle::framework::GradOriginalVarName(grad_var_name);
ASSERT_EQ(original_var_name, "XYZ");
-original_var_name = paddle::framework::OriginVarName(original_var_name);
+original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
ASSERT_EQ(original_var_name, "XYZ");
std::string var_name_3("");
grad_var_name = paddle::framework::GradVarName(var_name_3);
ASSERT_EQ(grad_var_name, "@GRAD");
-original_var_name = paddle::framework::OriginVarName(grad_var_name);
+original_var_name = paddle::framework::GradOriginalVarName(grad_var_name);
ASSERT_EQ(original_var_name, "");
-original_var_name = paddle::framework::OriginVarName(original_var_name);
+original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
ASSERT_EQ(original_var_name, "");
}
......@@ -32,6 +32,11 @@ using framework::Variable;
void AddTo(Variable* src, Variable* dst) {
framework::LoDTensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
framework::LoDTensor* src_tensor = src->GetMutable<framework::LoDTensor>();
+// FIXME(minqiyang): loss_grad op will pass a zero grad of label
+// ugly fix for it
+if (src_tensor->numel() == 0) {
+return;
+}
PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
"dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
src_tensor->numel());
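To make the intent of the added guard concrete, here is a small self-contained sketch (a hypothetical AddToSketch helper assuming float data on CPU; not the Paddle AddTo implementation) of gradient accumulation that treats an empty source tensor as "no gradient" instead of failing the size check:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Sketch: accumulate src into dst element-wise, but treat an empty src
// as "no gradient" and return early, mirroring the fix above for the
// zero-sized grad that loss_grad passes for the label.
void AddToSketch(const std::vector<float>& src, std::vector<float>* dst) {
  if (src.empty()) {
    return;  // nothing to accumulate
  }
  if (src.size() != dst->size()) {
    std::fprintf(stderr, "dst_numel %zu vs. src_numel %zu\n",
                 dst->size(), src.size());
    return;
  }
  for (std::size_t i = 0; i < src.size(); ++i) {
    (*dst)[i] += src[i];
  }
}
```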
......@@ -157,15 +162,11 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
auto& outputs = grad_outputs[it.first];
auto& origin_outputs = it.second;
-auto& forward_inputs = input_vars_[framework::OriginVarName(it.first)];
for (size_t i = 0; i < outputs.size(); ++i) {
-if (!forward_inputs[i]->stop_gradient_) {
framework::Variable* orig_grad = origin_outputs[i];
AddTo(outputs[i], orig_grad);
-}
}
}
return input_vars_;
}
......
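With the per-input stop_gradient_ filtering removed above, the surviving loop presumably accumulates every grad output into its origin unconditionally. A hedged, self-contained reconstruction with stand-in types (Variable and AddTo below are simplified placeholders, not the framework code):

```cpp
#include <cstddef>
#include <vector>

// Minimal stand-ins for framework::Variable and the AddTo shown above;
// hypothetical, for illustration only.
struct Variable { float value = 0.f; };
void AddTo(Variable* src, Variable* dst) { dst->value += src->value; }

// Sketch of the simplified OpBase::ApplyGrad loop: with stop_gradient_
// filtering removed, every grad output is added to its origin variable.
void AccumulateSketch(const std::vector<Variable*>& outputs,
                      const std::vector<Variable*>& origin_outputs) {
  for (std::size_t i = 0; i < outputs.size(); ++i) {
    Variable* orig_grad = origin_outputs[i];
    AddTo(outputs[i], orig_grad);
  }
}
```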
......@@ -81,7 +81,15 @@ class OpBase;
class VarBase {
public:
-explicit VarBase(bool stop_gradient = false)
+VarBase()
+: pre_op_(nullptr),
+pre_op_out_idx_(-1),
+var_desc_(nullptr),
+var_(new framework::Variable()),
+grads_(new framework::Variable()),
+stop_gradient_(false) {}
+explicit VarBase(bool stop_gradient)
: pre_op_(nullptr),
pre_op_out_idx_(-1),
var_desc_(nullptr),
......@@ -89,23 +97,12 @@ class VarBase {
grads_(new framework::Variable()),
stop_gradient_(stop_gradient) {}
-virtual ~VarBase() {
-if (var_) {
-delete var_;
-var_ = nullptr;
-}
-if (grads_) {
-delete grads_;
-grads_ = nullptr;
-}
-}
+virtual ~VarBase() {}
void RunBackward();
framework::LoDTensor& Grad();
inline framework::Variable* GradVar() { return grads_; }
inline std::string GradName() const {
PADDLE_ENFORCE(
var_desc_,
......
......@@ -57,7 +57,7 @@ class Tracer {
void Trace(OpBase* op,
const std::map<std::string, std::vector<VarBase*>>& inputs,
const std::map<std::string, std::vector<VarBase*>>& outputs,
-framework::BlockDesc* block, const bool stop_gradient) {
+framework::BlockDesc* block, const bool stop_gradient = false) {
std::map<std::string, VarBase*> vars;
framework::OpDesc* op_desc = op->op_desc_;
......@@ -153,6 +153,7 @@ class Tracer {
}
}
}
for (auto it : grad_op_desc->Outputs()) {
auto& grad_out_vars = op->grad_output_vars_[it.first];
for (const std::string& grad_outvar : it.second) {
......
......@@ -125,7 +125,8 @@ PYBIND11_MODULE(core, m) {
m.add_object("_cleanup",
py::capsule([]() { ScopePool::Instance().Clear(); }));
-py::class_<imperative::VarBase, PyVarBase>(m, "VarBase", R"DOC()DOC")
+py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
+m, "VarBase", R"DOC()DOC")
// .def(py::init<>())
.def(py::init<bool>(), py::arg("stop_gradient") = false)
.def("_run_backward",
......
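For context on the holder change, a minimal hypothetical pybind11 module (VarBaseLike and holder_demo are illustrative names, not Paddle code) showing the std::shared_ptr holder pattern the new py::class_ declaration adopts; with a shared_ptr holder, Python and C++ share ownership of the object instead of Python owning a raw pointer through the default unique_ptr holder:

```cpp
#include <memory>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical stand-in for imperative::VarBase, for illustration only.
struct VarBaseLike {
  explicit VarBaseLike(bool stop_gradient = false)
      : stop_gradient(stop_gradient) {}
  bool stop_gradient;
};

PYBIND11_MODULE(holder_demo, m) {
  // shared_ptr holder: the Python object and any C++ shared_ptr copies
  // keep the instance alive, so no manual delete is needed in a destructor.
  py::class_<VarBaseLike, std::shared_ptr<VarBaseLike>>(m, "VarBase")
      .def(py::init<bool>(), py::arg("stop_gradient") = false);
}
```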
......@@ -24,20 +24,29 @@ __all__ = ['PyLayer']
class PyLayer(core.Layer):
-def __init__(self, *args, **kwargs):
-self._once_built = True
+def __init__(self,
+dtype=core.VarDesc.VarType.FP32,
+param_attr=None,
+bias_attr=None,
+name=None):
from ..layer_helper import LayerHelper
-self._helper = LayerHelper(type(self).__name__, **kwargs)
-self._dtype = kwargs.get("dtype", core.VarDesc.VarType.FP32)
+self._helper = LayerHelper(
+type(self).__name__,
+param_attr=param_attr,
+bias_attr=bias_attr,
+dtype=dtype,
+name=name)
+self._once_built = False
+self._dtype = dtype
def _build_once(self, inputs):
pass
def __call__(self, *inputs):
-if self._once_built:
+if not self._once_built:
self._build_once(*inputs)
-self._once_built = False
+self._once_built = True
outputs = self.forward(*inputs)
......
......@@ -314,11 +314,9 @@ class LayerHelper(object):
WeightNormParamAttr.params_with_weight_norm.append(param)
return param
if _in_imperative_mode():
-self.main_program.global_block().create_parameter(
-dtype=dtype, shape=shape, **attr._to_kwargs())
# In imperative mode, we want the returned parameter to be
# initialized so that it can be used imperatively.
-return self.startup_program.global_block().create_parameter(
+return self.main_program.global_block().create_parameter(
dtype=dtype,
shape=shape,
**attr._to_kwargs(with_initializer=True))
......
......@@ -111,8 +111,6 @@ class TestImperativeMnist(unittest.TestCase):
predict = mnist(img)
out = fluid.layers.cross_entropy(predict, label)
out._backward()
-filter_grad = mnist._simple_img_conv_pool_1._conv2d._filter_param._gradient(
-)
sgd.minimize(out)
......