提交 2547f9d1 编写于 作者: M minqiyang

Polish code

test=develop
上级 09e2e662
...@@ -69,7 +69,7 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -69,7 +69,7 @@ inline std::string GradVarName(const std::string& var_name) {
return result; return result;
} }
inline std::string OriginVarName(const std::string& grad_var_name) { inline std::string GradOriginalVarName(const std::string& grad_var_name) {
std::size_t pos = grad_var_name.rfind(kGradVarSuffix); std::size_t pos = grad_var_name.rfind(kGradVarSuffix);
if (pos == std::string::npos) { if (pos == std::string::npos) {
return grad_var_name; return grad_var_name;
......
...@@ -294,24 +294,24 @@ TEST(VarNameTest, all) { ...@@ -294,24 +294,24 @@ TEST(VarNameTest, all) {
std::string grad_var_name = paddle::framework::GradVarName(var_name); std::string grad_var_name = paddle::framework::GradVarName(var_name);
ASSERT_EQ(grad_var_name, "X@GRAD"); ASSERT_EQ(grad_var_name, "X@GRAD");
std::string original_var_name = std::string original_var_name =
paddle::framework::OriginVarName(grad_var_name); paddle::framework::GradOriginalVarName(grad_var_name);
ASSERT_EQ(original_var_name, "X"); ASSERT_EQ(original_var_name, "X");
original_var_name = paddle::framework::OriginVarName(original_var_name); original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
ASSERT_EQ(original_var_name, "X"); ASSERT_EQ(original_var_name, "X");
std::string var_name_2("XYZ"); std::string var_name_2("XYZ");
grad_var_name = paddle::framework::GradVarName(var_name_2); grad_var_name = paddle::framework::GradVarName(var_name_2);
ASSERT_EQ(grad_var_name, "XYZ@GRAD"); ASSERT_EQ(grad_var_name, "XYZ@GRAD");
original_var_name = paddle::framework::OriginVarName(grad_var_name); original_var_name = paddle::framework::GradOriginalVarName(grad_var_name);
ASSERT_EQ(original_var_name, "XYZ"); ASSERT_EQ(original_var_name, "XYZ");
original_var_name = paddle::framework::OriginVarName(original_var_name); original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
ASSERT_EQ(original_var_name, "XYZ"); ASSERT_EQ(original_var_name, "XYZ");
std::string var_name_3(""); std::string var_name_3("");
grad_var_name = paddle::framework::GradVarName(var_name_3); grad_var_name = paddle::framework::GradVarName(var_name_3);
ASSERT_EQ(grad_var_name, "@GRAD"); ASSERT_EQ(grad_var_name, "@GRAD");
original_var_name = paddle::framework::OriginVarName(grad_var_name); original_var_name = paddle::framework::GradOriginalVarName(grad_var_name);
ASSERT_EQ(original_var_name, ""); ASSERT_EQ(original_var_name, "");
original_var_name = paddle::framework::OriginVarName(original_var_name); original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
ASSERT_EQ(original_var_name, ""); ASSERT_EQ(original_var_name, "");
} }
...@@ -32,6 +32,11 @@ using framework::Variable; ...@@ -32,6 +32,11 @@ using framework::Variable;
void AddTo(Variable* src, Variable* dst) { void AddTo(Variable* src, Variable* dst) {
framework::LoDTensor* dst_tensor = dst->GetMutable<framework::LoDTensor>(); framework::LoDTensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
framework::LoDTensor* src_tensor = src->GetMutable<framework::LoDTensor>(); framework::LoDTensor* src_tensor = src->GetMutable<framework::LoDTensor>();
// FIXME(minqiyang): loss_grad op will pass a zero grad of label
// ugly fix for it
if (src_tensor->numel() == 0) {
return;
}
PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(), PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
"dst_numel %lld vs. src_numel %lld", dst_tensor->numel(), "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
src_tensor->numel()); src_tensor->numel());
...@@ -157,13 +162,9 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() { ...@@ -157,13 +162,9 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
auto& outputs = grad_outputs[it.first]; auto& outputs = grad_outputs[it.first];
auto& origin_outputs = it.second; auto& origin_outputs = it.second;
auto& forward_inputs = input_vars_[framework::OriginVarName(it.first)];
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
if (!forward_inputs[i]->stop_gradient_) { framework::Variable* orig_grad = origin_outputs[i];
framework::Variable* orig_grad = origin_outputs[i]; AddTo(outputs[i], orig_grad);
AddTo(outputs[i], orig_grad);
}
} }
} }
return input_vars_; return input_vars_;
......
...@@ -81,7 +81,15 @@ class OpBase; ...@@ -81,7 +81,15 @@ class OpBase;
class VarBase { class VarBase {
public: public:
explicit VarBase(bool stop_gradient = false) VarBase()
: pre_op_(nullptr),
pre_op_out_idx_(-1),
var_desc_(nullptr),
var_(new framework::Variable()),
grads_(new framework::Variable()),
stop_gradient_(false) {}
explicit VarBase(bool stop_gradient)
: pre_op_(nullptr), : pre_op_(nullptr),
pre_op_out_idx_(-1), pre_op_out_idx_(-1),
var_desc_(nullptr), var_desc_(nullptr),
...@@ -89,23 +97,12 @@ class VarBase { ...@@ -89,23 +97,12 @@ class VarBase {
grads_(new framework::Variable()), grads_(new framework::Variable()),
stop_gradient_(stop_gradient) {} stop_gradient_(stop_gradient) {}
virtual ~VarBase() { virtual ~VarBase() {}
if (var_) {
delete var_;
var_ = nullptr;
}
if (grads_) {
delete grads_;
grads_ = nullptr;
}
}
void RunBackward(); void RunBackward();
framework::LoDTensor& Grad(); framework::LoDTensor& Grad();
inline framework::Variable* GradVar() { return grads_; }
inline std::string GradName() const { inline std::string GradName() const {
PADDLE_ENFORCE( PADDLE_ENFORCE(
var_desc_, var_desc_,
......
...@@ -57,7 +57,7 @@ class Tracer { ...@@ -57,7 +57,7 @@ class Tracer {
void Trace(OpBase* op, void Trace(OpBase* op,
const std::map<std::string, std::vector<VarBase*>>& inputs, const std::map<std::string, std::vector<VarBase*>>& inputs,
const std::map<std::string, std::vector<VarBase*>>& outputs, const std::map<std::string, std::vector<VarBase*>>& outputs,
framework::BlockDesc* block, const bool stop_gradient) { framework::BlockDesc* block, const bool stop_gradient = false) {
std::map<std::string, VarBase*> vars; std::map<std::string, VarBase*> vars;
framework::OpDesc* op_desc = op->op_desc_; framework::OpDesc* op_desc = op->op_desc_;
...@@ -153,6 +153,7 @@ class Tracer { ...@@ -153,6 +153,7 @@ class Tracer {
} }
} }
} }
for (auto it : grad_op_desc->Outputs()) { for (auto it : grad_op_desc->Outputs()) {
auto& grad_out_vars = op->grad_output_vars_[it.first]; auto& grad_out_vars = op->grad_output_vars_[it.first];
for (const std::string& grad_outvar : it.second) { for (const std::string& grad_outvar : it.second) {
......
...@@ -125,7 +125,8 @@ PYBIND11_MODULE(core, m) { ...@@ -125,7 +125,8 @@ PYBIND11_MODULE(core, m) {
m.add_object("_cleanup", m.add_object("_cleanup",
py::capsule([]() { ScopePool::Instance().Clear(); })); py::capsule([]() { ScopePool::Instance().Clear(); }));
py::class_<imperative::VarBase, PyVarBase>(m, "VarBase", R"DOC()DOC") py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
m, "VarBase", R"DOC()DOC")
// .def(py::init<>()) // .def(py::init<>())
.def(py::init<bool>(), py::arg("stop_gradient") = false) .def(py::init<bool>(), py::arg("stop_gradient") = false)
.def("_run_backward", .def("_run_backward",
......
...@@ -24,20 +24,29 @@ __all__ = ['PyLayer'] ...@@ -24,20 +24,29 @@ __all__ = ['PyLayer']
class PyLayer(core.Layer): class PyLayer(core.Layer):
def __init__(self, *args, **kwargs): def __init__(self,
self._once_built = True dtype=core.VarDesc.VarType.FP32,
param_attr=None,
bias_attr=None,
name=None):
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
self._helper = LayerHelper(type(self).__name__, **kwargs) self._helper = LayerHelper(
self._dtype = kwargs.get("dtype", core.VarDesc.VarType.FP32) type(self).__name__,
param_attr=param_attr,
bias_attr=bias_attr,
dtype=dtype,
name=name)
self._once_built = False
self._dtype = dtype
def _build_once(self, inputs): def _build_once(self, inputs):
pass pass
def __call__(self, *inputs): def __call__(self, *inputs):
if self._once_built: if not self._once_built:
self._build_once(*inputs) self._build_once(*inputs)
self._once_built = False self._once_built = True
outputs = self.forward(*inputs) outputs = self.forward(*inputs)
......
...@@ -314,11 +314,9 @@ class LayerHelper(object): ...@@ -314,11 +314,9 @@ class LayerHelper(object):
WeightNormParamAttr.params_with_weight_norm.append(param) WeightNormParamAttr.params_with_weight_norm.append(param)
return param return param
if _in_imperative_mode(): if _in_imperative_mode():
self.main_program.global_block().create_parameter(
dtype=dtype, shape=shape, **attr._to_kwargs())
# In imperative mode, we want the returned parameter to be # In imperative mode, we want the returned parameter to be
# initialized so that it can be used imperatively. # initialized so that it can be used imperatively.
return self.startup_program.global_block().create_parameter( return self.main_program.global_block().create_parameter(
dtype=dtype, dtype=dtype,
shape=shape, shape=shape,
**attr._to_kwargs(with_initializer=True)) **attr._to_kwargs(with_initializer=True))
......
...@@ -111,8 +111,6 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -111,8 +111,6 @@ class TestImperativeMnist(unittest.TestCase):
predict = mnist(img) predict = mnist(img)
out = fluid.layers.cross_entropy(predict, label) out = fluid.layers.cross_entropy(predict, label)
out._backward() out._backward()
filter_grad = mnist._simple_img_conv_pool_1._conv2d._filter_param._gradient(
)
sgd.minimize(out) sgd.minimize(out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册