提交 47329f6b 编写于 作者: X xzl

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_dilation

...@@ -28,3 +28,4 @@ cmake_install.cmake ...@@ -28,3 +28,4 @@ cmake_install.cmake
paddle/.timestamp paddle/.timestamp
python/paddlepaddle.egg-info/ python/paddlepaddle.egg-info/
paddle/pybind/pybind.h paddle/pybind/pybind.h
python/paddle/v2/framework/tests/tmp/*
...@@ -15,7 +15,7 @@ nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) ...@@ -15,7 +15,7 @@ nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc) cc_library(scope SRCS scope.cc DEPS glog)
cc_test(scope_test SRCS scope_test.cc DEPS scope) cc_test(scope_test SRCS scope_test.cc DEPS scope)
......
...@@ -107,6 +107,8 @@ class OpDescBind { ...@@ -107,6 +107,8 @@ class OpDescBind {
void InferVarType(BlockDescBind *block) const; void InferVarType(BlockDescBind *block) const;
void MarkAsTarget() { desc_.set_is_target(true); }
void Flush(); void Flush();
private: private:
......
...@@ -49,6 +49,13 @@ ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) { ...@@ -49,6 +49,13 @@ ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
} }
} }
ProgramDescBind::ProgramDescBind(const ProgramDesc &desc) {
desc_ = desc;
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc));
}
}
ProgramDescBind::ProgramDescBind(const std::string &binary_str) { ProgramDescBind::ProgramDescBind(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str), PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string."); "Fail to parse program_desc from binary string.");
......
...@@ -29,6 +29,8 @@ class ProgramDescBind { ...@@ -29,6 +29,8 @@ class ProgramDescBind {
public: public:
ProgramDescBind(); ProgramDescBind();
explicit ProgramDescBind(const ProgramDesc &desc);
ProgramDescBind(const ProgramDescBind &o); ProgramDescBind(const ProgramDescBind &o);
explicit ProgramDescBind(const std::string &binary_str); explicit ProgramDescBind(const std::string &binary_str);
......
...@@ -46,7 +46,7 @@ bool IsTarget(const OpDesc& op_desc) { ...@@ -46,7 +46,7 @@ bool IsTarget(const OpDesc& op_desc) {
return false; return false;
} }
void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) { void prune_impl(const ProgramDesc& input, ProgramDesc* output, int block_id) {
// TODO(tonyyang-svail): // TODO(tonyyang-svail):
// - will change to use multiple blocks for RNN op and Cond Op // - will change to use multiple blocks for RNN op and Cond Op
...@@ -91,8 +91,8 @@ void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) { ...@@ -91,8 +91,8 @@ void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
// we reverse the should_run vector // we reverse the should_run vector
std::reverse(should_run.begin(), should_run.end()); std::reverse(should_run.begin(), should_run.end());
output = input; *output = input;
auto* op_field = output.mutable_blocks(block_id)->mutable_ops(); auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
op_field->Clear(); op_field->Clear();
for (size_t i = 0; i < should_run.size(); ++i) { for (size_t i = 0; i < should_run.size(); ++i) {
if (should_run[i]) { if (should_run[i]) {
...@@ -101,7 +101,8 @@ void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) { ...@@ -101,7 +101,8 @@ void prune_impl(const ProgramDesc& input, ProgramDesc& output, int block_id) {
} }
} }
void Prune(const ProgramDesc& input, ProgramDesc& output) { // TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
void Prune(const ProgramDesc& input, ProgramDesc* output) {
prune_impl(input, output, 0); prune_impl(input, output, 0);
} }
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void Prune(const ProgramDesc& input, ProgramDesc& output); void Prune(const ProgramDesc& input, ProgramDesc* output);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -59,11 +59,11 @@ TEST(Prune, one_operator) { ...@@ -59,11 +59,11 @@ TEST(Prune, one_operator) {
f::ProgramDesc *pdesc = program.Proto(); f::ProgramDesc *pdesc = program.Proto();
f::ProgramDesc pruned; f::ProgramDesc pruned;
Prune(*pdesc, pruned); Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 0);
pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(0)->set_is_target(true);
Prune(*pdesc, pruned); Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 1);
} }
...@@ -81,7 +81,7 @@ TEST(Prune, forward) { ...@@ -81,7 +81,7 @@ TEST(Prune, forward) {
for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) { for (int i = 0; i < pdesc->blocks(0).ops_size(); ++i) {
f::ProgramDesc pruned; f::ProgramDesc pruned;
pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(i)->set_is_target(true);
Prune(*pdesc, pruned); Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), i + 1);
} }
} }
...@@ -100,7 +100,7 @@ TEST(Prune, multi_input_op) { ...@@ -100,7 +100,7 @@ TEST(Prune, multi_input_op) {
pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(3)->set_is_target(true);
f::ProgramDesc pruned; f::ProgramDesc pruned;
Prune(*pdesc, pruned); Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 4);
} }
...@@ -116,7 +116,7 @@ TEST(Prune, multi_output_op) { ...@@ -116,7 +116,7 @@ TEST(Prune, multi_output_op) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned; f::ProgramDesc pruned;
Prune(*pdesc, pruned); Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 2);
} }
...@@ -133,6 +133,6 @@ TEST(Prune, multi_target) { ...@@ -133,6 +133,6 @@ TEST(Prune, multi_target) {
pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true); pdesc->mutable_blocks(0)->mutable_ops(2)->set_is_target(true);
f::ProgramDesc pruned; f::ProgramDesc pruned;
Prune(*pdesc, pruned); Prune(*pdesc, &pruned);
PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3); PADDLE_ENFORCE_EQ(pruned.blocks(0).ops_size(), 3);
} }
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
#include <mutex> // for call_once #include <mutex> // for call_once
#include "glog/logging.h"
#include "paddle/string/printf.h" #include "paddle/string/printf.h"
namespace paddle { namespace paddle {
...@@ -23,7 +24,10 @@ namespace framework { ...@@ -23,7 +24,10 @@ namespace framework {
Scope::~Scope() { Scope::~Scope() {
DropKids(); DropKids();
for (auto& kv : vars_) delete kv.second; for (auto& kv : vars_) {
VLOG(3) << "Destroy variable " << kv.first;
delete kv.second;
}
} }
Scope& Scope::NewScope() const { Scope& Scope::NewScope() const {
...@@ -38,6 +42,7 @@ Variable* Scope::Var(const std::string& name) { ...@@ -38,6 +42,7 @@ Variable* Scope::Var(const std::string& name) {
} }
Variable* v = new Variable(); Variable* v = new Variable();
vars_[name] = v; vars_[name] = v;
VLOG(3) << "Create variable " << name << " on scope";
v->name_ = &(vars_.find(name)->first); v->name_ = &(vars_.find(name)->first);
return v; return v;
} }
......
add_subdirectory(detail) add_subdirectory(detail)
cc_library(memory SRCS memory.cc) cc_library(memory SRCS memory.cc DEPS place)
cc_library(memcpy SRCS memcpy.cc) cc_library(memcpy SRCS memcpy.cc)
cc_library(paddle_memory cc_library(paddle_memory
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include "paddle/memory/detail/meta_cache.h" #include "paddle/memory/detail/meta_cache.h"
#include "glog/logging.h"
#include "paddle/memory/detail/memory_block.h" #include "paddle/memory/detail/memory_block.h"
#include "paddle/platform/assert.h" #include "paddle/platform/assert.h"
...@@ -28,7 +29,9 @@ Metadata MetadataCache::load(const MemoryBlock* block) { ...@@ -28,7 +29,9 @@ Metadata MetadataCache::load(const MemoryBlock* block) {
PADDLE_ASSERT(existing_metadata->second.check_guards()); PADDLE_ASSERT(existing_metadata->second.check_guards());
return existing_metadata->second; return existing_metadata->second;
} else { } else {
PADDLE_ASSERT(reinterpret_cast<const Metadata*>(block)->check_guards()); auto* meta = reinterpret_cast<const Metadata*>(block);
VLOG(3) << "Load MetaData type=" << meta->type;
PADDLE_ASSERT(meta->check_guards());
return *reinterpret_cast<const Metadata*>(block); return *reinterpret_cast<const Metadata*>(block);
} }
} }
......
...@@ -39,11 +39,15 @@ BuddyAllocator* GetCPUBuddyAllocator() { ...@@ -39,11 +39,15 @@ BuddyAllocator* GetCPUBuddyAllocator() {
template <> template <>
void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size) { void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size) {
return GetCPUBuddyAllocator()->Alloc(size); VLOG(3) << "Allocate " << size << " bytes on " << platform::Place(place);
void* p = GetCPUBuddyAllocator()->Alloc(size);
VLOG(3) << " pointer=" << p;
return p;
} }
template <> template <>
void Free<platform::CPUPlace>(platform::CPUPlace place, void* p) { void Free<platform::CPUPlace>(platform::CPUPlace place, void* p) {
VLOG(3) << "Free pointer=" << p << " on " << platform::Place(place);
GetCPUBuddyAllocator()->Free(p); GetCPUBuddyAllocator()->Free(p);
} }
......
...@@ -117,9 +117,6 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> { ...@@ -117,9 +117,6 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
math::SetConstant<platform::GPUPlace, T> functor; math::SetConstant<platform::GPUPlace, T> functor;
functor(ctx.device_context(), saved_mean, 0); functor(ctx.device_context(), saved_mean, 0);
functor(ctx.device_context(), saved_variance, 0); functor(ctx.device_context(), saved_variance, 0);
// FIXME(qiao) should not set zero self
functor(ctx.device_context(), mean_out, 0);
functor(ctx.device_context(), variance_out, 0);
auto handle = ctx.cuda_device_context().cudnn_handle(); auto handle = ctx.cuda_device_context().cudnn_handle();
...@@ -211,8 +208,15 @@ class BatchNormGradKernel<platform::GPUPlace, T> ...@@ -211,8 +208,15 @@ class BatchNormGradKernel<platform::GPUPlace, T>
mode_ = CUDNN_BATCHNORM_SPATIAL; mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif #endif
std::vector<int> dims = {N, C, H, W, D}; std::vector<int> dims;
std::vector<int> strides = {H * W * C * D, 1, W * D * C, D * C, C}; std::vector<int> strides;
if (tensor_format == TensorFormat::NCHW) {
dims = {N, C, H, W, D};
strides = {C * H * W * D, H * W * D, W * D, D, 1};
} else {
dims = {N, C, H, W, D};
strides = {H * W * C * D, 1, W * D * C, D * C, C};
}
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type, data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
......
...@@ -144,11 +144,11 @@ class SequencePoolGradKernel : public framework::OpKernel<T> { ...@@ -144,11 +144,11 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
in_t_map(in_t.data<T>(), h, w); in_t_map(in_t.data<T>(), h, w);
int row_id; int row_id;
Eigen::array<int, 2> extents = {1, 1}; Eigen::array<int, 2> extents{{1, 1}};
for (int col_id = 0; col_id < w; col_id++) { for (int col_id = 0; col_id < w; col_id++) {
in_t_map.col(col_id).maxCoeff(&row_id); in_t_map.col(col_id).maxCoeff(&row_id);
Eigen::array<int, 2> in_offsets = {row_id, col_id}; Eigen::array<int, 2> in_offsets{{row_id, col_id}};
Eigen::array<int, 2> out_offsets = {0, col_id}; Eigen::array<int, 2> out_offsets{{0, col_id}};
in_g_e.slice(in_offsets, extents).device(place) = in_g_e.slice(in_offsets, extents).device(place) =
out_g_e.slice(out_offsets, extents); out_g_e.slice(out_offsets, extents);
} }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sign_op.h"
namespace paddle {
namespace operators {
class SignOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SignOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SignOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
template <typename AttrType>
class SignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SignOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor of sign operator.");
AddOutput("Out", "(Tensor) Output tensor of sign operator.");
AddComment(R"DOC(Sign operator
The equation is: Out = X.sign()
)DOC");
}
};
class SignGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto *grad_op = new framework::OpDescBind();
grad_op->SetType("scale");
grad_op->SetInput("X", OutputGrad("Out"));
grad_op->SetOutput("Out", InputGrad("X"));
grad_op->SetAttr("scale", 0.0f);
return std::unique_ptr<framework::OpDescBind>(grad_op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker<float>,
ops::SignGradMaker);
REGISTER_OP_CPU_KERNEL(sign,
ops::SignKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sign_op.h"
REGISTER_OP_GPU_KERNEL(
sign, paddle::operators::SignKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class SignKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& context) const {
auto* out = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
out->mutable_data<T>(in->place());
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& place = context.GetEigenDevice<Place>();
eigen_out.device(place) = eigen_in.sign();
}
};
} // namespace operators
} // namespace paddle
if(WITH_PYTHON) if(WITH_PYTHON)
cc_library(paddle_pybind SHARED cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc SRCS pybind.cc exception.cc protobuf.cc
DEPS pybind python backward proto_desc tensor_array paddle_memory executor DEPS pybind python backward proto_desc tensor_array paddle_memory executor prune
${GLOB_OP_LIB}) ${GLOB_OP_LIB})
endif(WITH_PYTHON) endif(WITH_PYTHON)
......
...@@ -141,6 +141,13 @@ void BindProgramDesc(py::module &m) { ...@@ -141,6 +141,13 @@ void BindProgramDesc(py::module &m) {
desc->SerializeToString(&res), desc->SerializeToString(&res),
"Serialize ProgramDesc Error. This could be a bug of Paddle."); "Serialize ProgramDesc Error. This could be a bug of Paddle.");
return res; return res;
})
.def("parse_from_string",
[](ProgramDescBind &program_desc, const std::string &data) {
ProgramDesc *desc = program_desc.Proto();
PADDLE_ENFORCE(desc->ParseFromString(data),
"Fail to parse ProgramDesc from string. This could "
"be a bug of Paddle.");
}); });
} }
......
...@@ -19,6 +19,7 @@ limitations under the License. */ ...@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/framework/feed_fetch_method.h" #include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/prune.h"
#include "paddle/framework/selected_rows.h" #include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor_array.h" #include "paddle/framework/tensor_array.h"
#include "paddle/operators/cond_op.h" #include "paddle/operators/cond_op.h"
...@@ -237,6 +238,16 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -237,6 +238,16 @@ All parameter, weight, gradient are variables in Paddle.
} }
return ret_values; return ret_values;
}); });
m.def("prune", [](const ProgramDescBind &origin,
const std::vector<std::array<size_t, 2>> &targets) {
ProgramDescBind prog_with_targets(origin);
for (const auto &t : targets) {
prog_with_targets.Block(t[0])->Op(t[1])->MarkAsTarget();
}
ProgramDesc pruned_desc;
Prune(*prog_with_targets.Proto(), &pruned_desc);
return new ProgramDescBind(pruned_desc);
});
m.def_submodule( m.def_submodule(
"var_names", "var_names",
"The module will return special predefined variable name in Paddle") "The module will return special predefined variable name in Paddle")
......
...@@ -251,6 +251,8 @@ class Operator(object): ...@@ -251,6 +251,8 @@ class Operator(object):
self.desc.set_output(out_proto.name, out_argu_names) self.desc.set_output(out_proto.name, out_argu_names)
if attrs is not None: if attrs is not None:
if not isinstance(attrs, dict):
raise TypeError("'attrs' should be a dict.")
for attr in proto.attrs: for attr in proto.attrs:
attr_name = attr.name attr_name = attr.name
if (not attr_name in attrs) or (attrs[attr_name] is None): if (not attr_name in attrs) or (attrs[attr_name] is None):
...@@ -291,6 +293,14 @@ class Operator(object): ...@@ -291,6 +293,14 @@ class Operator(object):
def output_names(self): def output_names(self):
return self.desc.output_names() return self.desc.output_names()
@property
def idx(self):
for i, op in enumerate(self.block.ops):
if op == self:
return i
raise ValueError(
"Can't find op itself in it's block. It could be a bug of Paddle.")
def has_attr(self, name): def has_attr(self, name):
return self.desc.has_attr(name) return self.desc.has_attr(name)
...@@ -440,10 +450,31 @@ class Program(object): ...@@ -440,10 +450,31 @@ class Program(object):
p.sync_with_cpp() p.sync_with_cpp()
return p return p
def prune(self, targets):
if not isinstance(targets, list):
targets = [targets]
targets_idx = []
for t in targets:
if not isinstance(t, Operator):
if isinstance(t, Variable):
t = t.op
else:
raise ValueError(
"All targets of prune() can only be Variable or Operator."
)
targets_idx.append([t.block.idx, t.idx])
res = Program()
res.desc = core.prune(self.desc, targets_idx)
res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())]
res.sync_with_cpp()
return res
@staticmethod @staticmethod
def parse_from_string(binary_str): def parse_from_string(binary_str):
p = Program() p = Program()
p.desc = core.ProgramDesc(binary_str) p.desc = core.ProgramDesc(binary_str)
p.blocks = [Block(p, i) for i in xrange(p.desc.num_blocks())]
p.sync_with_cpp() p.sync_with_cpp()
return p return p
......
import os import os
import cPickle as pickle
from paddle.v2.framework.framework import Program, Parameter, g_program, \ from paddle.v2.framework.framework import Program, Parameter, g_program, \
Variable Variable
__all__ = [ __all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables' 'load_persistables', "save_inference_model", "load_inference_model"
] ]
...@@ -31,7 +32,7 @@ def _clone_var_in_block_(block, var): ...@@ -31,7 +32,7 @@ def _clone_var_in_block_(block, var):
def save_vars(executor, dirname, program=None, vars=None, predicate=None): def save_vars(executor, dirname, program=None, vars=None, predicate=None):
""" """
Save variables to directory by executor. Save variables to directory by executor.
:param executor: executor that save variable :param executor: executor that save variable
:param dirname: directory path :param dirname: directory path
:param program: program. If vars is None, then filter all variables in this :param program: program. If vars is None, then filter all variables in this
...@@ -92,7 +93,7 @@ def save_persistables(executor, dirname, program=None): ...@@ -92,7 +93,7 @@ def save_persistables(executor, dirname, program=None):
def load_vars(executor, dirname, program=None, vars=None, predicate=None): def load_vars(executor, dirname, program=None, vars=None, predicate=None):
""" """
Load variables from directory by executor. Load variables from directory by executor.
:param executor: executor that save variable :param executor: executor that save variable
:param dirname: directory path :param dirname: directory path
:param program: program. If vars is None, then filter all variables in this :param program: program. If vars is None, then filter all variables in this
...@@ -124,6 +125,7 @@ def load_vars(executor, dirname, program=None, vars=None, predicate=None): ...@@ -124,6 +125,7 @@ def load_vars(executor, dirname, program=None, vars=None, predicate=None):
inputs={}, inputs={},
outputs={"Out": [new_var]}, outputs={"Out": [new_var]},
attrs={'file_path': os.path.join(dirname, new_var.name)}) attrs={'file_path': os.path.join(dirname, new_var.name)})
executor.run(load_prog) executor.run(load_prog)
...@@ -141,3 +143,88 @@ def load_persistables(executor, dirname, program=None): ...@@ -141,3 +143,88 @@ def load_persistables(executor, dirname, program=None):
""" """
load_vars( load_vars(
executor, dirname=dirname, program=program, predicate=is_persistable) executor, dirname=dirname, program=program, predicate=is_persistable)
def save_inference_model(dirname,
feeded_var_names,
target_vars,
executor,
program=None):
"""
Build a model especially for inference,
and save it to directory by the executor.
:param dirname: directory path
:param feeded_var_names: Names of variables that need to be feeded data during inference
:param target_vars: Variables from which we can get inference results.
:param executor: executor that save inference model
:param program: original program, which will be pruned to build the inference model.
Default g_program.
:return: None
"""
if program is None:
program = g_program
if not isinstance(target_vars, list):
target_vars = [target_vars]
if not os.path.isdir(dirname):
os.makedirs(dirname)
pruned_program = program.prune(target_vars)
fetch_var_names = [v.name for v in target_vars]
model_file_name = dirname + "/__model__"
with open(model_file_name, "w") as f:
pickle.dump({
"program_desc_str": pruned_program.desc.serialize_to_string(),
"feed_var_names": feeded_var_names,
"fetch_var_names": fetch_var_names
}, f, -1)
save_params(executor, dirname, program)
def load_persistables_if_exist(executor, dirname, program=None):
filenames = next(os.walk(dirname))[2]
filenames = set(filenames)
def _is_presistable_and_exist_(var):
if not is_persistable(var):
return False
else:
return var.name in filenames
load_vars(
executor,
dirname,
program=program,
vars=None,
predicate=_is_presistable_and_exist_)
def load_inference_model(dirname, executor):
"""
Load inference model from a directory
:param dirname: directory path
:param executor: executor that load inference model
:return: [program, feed_var_names, fetch_var_names]
program: program especially for inference.
feeded_var_names: Names of variables that need to feed data
fetch_vars: Variables from which we can get inference results.
"""
if not os.path.isdir(dirname):
raise ValueError("There is no directory named '%s'", dirname)
model_file_name = dirname + "/__model__"
model = pickle.load(open(model_file_name, "r"))
program_desc_str = model["program_desc_str"]
feed_var_names = model["feed_var_names"]
fetch_var_names = model["fetch_var_names"]
program = Program.parse_from_string(program_desc_str)
load_persistables_if_exist(executor, dirname, program)
fetch_vars = [program.global_block().var(name) for name in fetch_var_names]
return [program, feed_var_names, fetch_vars]
...@@ -131,12 +131,14 @@ class LayerHelper(object): ...@@ -131,12 +131,14 @@ class LayerHelper(object):
return dtype return dtype
def create_parameter(self, attr, shape, dtype, suffix='w'): def create_parameter(self, attr, shape, dtype, suffix='w'):
if attr['name'] is None: # Deepcopy the attr so that parameters can be shared in program
attr['name'] = unique_name(".".join([self.name, suffix])) attr_copy = copy.deepcopy(attr)
if attr_copy['name'] is None:
attr_copy['name'] = unique_name(".".join([self.name, suffix]))
self.init_program.global_block().create_parameter( self.init_program.global_block().create_parameter(
dtype=dtype, shape=shape, **attr) dtype=dtype, shape=shape, **attr_copy)
return self.program.global_block().create_parameter( return self.program.global_block().create_parameter(
name=attr['name'], dtype=dtype, shape=shape) name=attr_copy['name'], dtype=dtype, shape=shape)
def create_tmp_variable(self, dtype): def create_tmp_variable(self, dtype):
return self.program.current_block().create_var( return self.program.current_block().create_var(
......
...@@ -18,7 +18,8 @@ class Optimizer(object): ...@@ -18,7 +18,8 @@ class Optimizer(object):
but need to use one of it's implementation. but need to use one of it's implementation.
""" """
def __init__(self): def __init__(self, global_step=None):
self._global_step = global_step
# Dictionary of accumulators. Some optimizer subclasses need to # Dictionary of accumulators. Some optimizer subclasses need to
# allocate and manage extra variables associated with the parameters # allocate and manage extra variables associated with the parameters
# to train. These variables are called accumulators. # to train. These variables are called accumulators.
...@@ -109,6 +110,26 @@ class Optimizer(object): ...@@ -109,6 +110,26 @@ class Optimizer(object):
format(name, param.name)) format(name, param.name))
return self._accumulators[name][param.name] return self._accumulators[name][param.name]
def _increment_global_step(self, block):
"""Increment the global step by 1 after every iteration
Args:
block: the block in which the loss variable is present
Returns:
list with global_step increment op as its only element
"""
assert isinstance(block, framework.Block)
assert self._global_step is not None
# create the increment op
increment_op = block.append_op(
type="increment",
inputs={"X": self._global_step},
outputs={"Out": self._global_step},
attrs={"step": 1.0})
return increment_op
def create_optimization_pass(self, parameters_and_grads, loss): def create_optimization_pass(self, parameters_and_grads, loss):
"""Add optimization operators to update gradients to variables. """Add optimization operators to update gradients to variables.
...@@ -152,6 +173,8 @@ class Optimizer(object): ...@@ -152,6 +173,8 @@ class Optimizer(object):
if finish_ops is not None: if finish_ops is not None:
return_ops += finish_ops return_ops += finish_ops
if self._global_step is not None:
return_ops.append(self._increment_global_step(loss.block))
return return_ops return return_ops
def minimize(self, loss, parameter_list=None, no_grad_set=None): def minimize(self, loss, parameter_list=None, no_grad_set=None):
...@@ -172,9 +195,9 @@ class SGDOptimizer(Optimizer): ...@@ -172,9 +195,9 @@ class SGDOptimizer(Optimizer):
""" Simple SGD optimizer without any state. """ Simple SGD optimizer without any state.
""" """
def __init__(self, learning_rate): def __init__(self, learning_rate, global_step=None):
assert learning_rate is not None assert learning_rate is not None
super(SGDOptimizer, self).__init__() super(SGDOptimizer, self).__init__(global_step)
self.type = "sgd" self.type = "sgd"
self._learning_rate = learning_rate self._learning_rate = learning_rate
...@@ -215,10 +238,14 @@ class MomentumOptimizer(Optimizer): ...@@ -215,10 +238,14 @@ class MomentumOptimizer(Optimizer):
""" """
_velocity_acc_str = "velocity" _velocity_acc_str = "velocity"
def __init__(self, learning_rate, momentum, use_nesterov=False): def __init__(self,
learning_rate,
momentum,
use_nesterov=False,
global_step=None):
assert learning_rate is not None assert learning_rate is not None
assert momentum is not None assert momentum is not None
super(MomentumOptimizer, self).__init__() super(MomentumOptimizer, self).__init__(global_step)
self.type = "momentum" self.type = "momentum"
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._momentum = momentum self._momentum = momentum
...@@ -275,10 +302,10 @@ class AdagradOptimizer(Optimizer): ...@@ -275,10 +302,10 @@ class AdagradOptimizer(Optimizer):
""" """
_moment_acc_str = "moment" _moment_acc_str = "moment"
def __init__(self, learning_rate, epsilon=1.0e-6): def __init__(self, learning_rate, epsilon=1.0e-6, global_step=None):
assert learning_rate is not None assert learning_rate is not None
assert epsilon is not None assert epsilon is not None
super(AdagradOptimizer, self).__init__() super(AdagradOptimizer, self).__init__(global_step)
self.type = "adagrad" self.type = "adagrad"
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._epsilon = epsilon self._epsilon = epsilon
...@@ -337,12 +364,13 @@ class AdamOptimizer(Optimizer): ...@@ -337,12 +364,13 @@ class AdamOptimizer(Optimizer):
learning_rate=0.001, learning_rate=0.001,
beta1=0.9, beta1=0.9,
beta2=0.999, beta2=0.999,
epsilon=1e-8): epsilon=1e-8,
global_step=None):
assert learning_rate is not None assert learning_rate is not None
assert beta1 is not None assert beta1 is not None
assert beta2 is not None assert beta2 is not None
assert epsilon is not None assert epsilon is not None
super(AdamOptimizer, self).__init__() super(AdamOptimizer, self).__init__(global_step)
self.type = "adam" self.type = "adam"
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._beta1 = beta1 self._beta1 = beta1
...@@ -458,7 +486,8 @@ class AdamaxOptimizer(Optimizer): ...@@ -458,7 +486,8 @@ class AdamaxOptimizer(Optimizer):
learning_rate=0.001, learning_rate=0.001,
beta1=0.9, beta1=0.9,
beta2=0.999, beta2=0.999,
epsilon=1e-8): epsilon=1e-8,
global_step=None):
assert learning_rate is not None assert learning_rate is not None
assert beta1 is not None assert beta1 is not None
assert beta2 is not None assert beta2 is not None
......
import paddle.v2.framework.framework as framework import paddle.v2.framework.framework as framework
__all__ = ['append_regularization_ops', 'L2DecayRegularizer'] __all__ = [
'append_regularization_ops', 'L2DecayRegularizer', 'L1DecayRegularizer'
]
def append_regularization_ops(parameters_and_grads): def append_regularization_ops(parameters_and_grads):
...@@ -97,3 +99,43 @@ class L2DecayRegularizer(WeightDecayRegularizer): ...@@ -97,3 +99,43 @@ class L2DecayRegularizer(WeightDecayRegularizer):
attrs={"scale": self._regularization_coeff}) attrs={"scale": self._regularization_coeff})
return decay return decay
class L1DecayRegularizer(WeightDecayRegularizer):
"""Implements the L1 Weight Decay Regularization
"""
def __init__(self, regularization_coeff=0.0):
assert regularization_coeff is not None
super(L1DecayRegularizer, self).__init__()
self._regularization_coeff = regularization_coeff
def __call__(self, param, block):
"""Add L1 weight decay ops to network
Adds L1 weight decay ops.
L1WeightDecay = reg_coeff * sign(parameter)
Args:
param: parameter variable for which regularization is applied
block: block in which variable is to be created
Returns:
new variable for weight decay
"""
assert isinstance(param, framework.Parameter)
assert isinstance(block, framework.Block)
decay = block.create_var(
dtype="float32", shape=param.shape, lod_level=param.lod_level)
# Append sign op
block.append_op(
type='sign', inputs={"X": param}, outputs={"Out": decay})
# Append scale op to the output of sign op
block.append_op(
type='scale',
inputs={"X": decay},
outputs={"Out": decay},
attrs={"scale": self._regularization_coeff})
return decay
...@@ -21,16 +21,36 @@ def get_backward_op(scope, op, no_grad_set): ...@@ -21,16 +21,36 @@ def get_backward_op(scope, op, no_grad_set):
def _reference_training(x, scale, offset, epsilon, data_format): def _reference_training(x, scale, offset, epsilon, data_format):
if data_format != "NHWC": if data_format == "NCHW":
raise ValueError("data_format must be NHWC, got %s." % data_format) n, c, h, w = x.shape
x_square = x * x x_square = x * x
x_square_sum = np.sum(x_square, (0, 1, 2)) x_square_sum = np.sum(x_square, (0, 2, 3))
x_sum = np.sum(x, axis=(0, 1, 2)) x_sum = np.sum(x, axis=(0, 2, 3))
element_count = np.size(x) / int(np.shape(x)[-1]) element_count = np.size(x) / int(np.shape(x)[1])
mean = x_sum / element_count mean = x_sum / element_count
var = x_square_sum / element_count - mean * mean var = x_square_sum / element_count - mean * mean
normalized = (x - mean) / np.sqrt(var + epsilon) mean_tile = np.reshape(mean, (1, c, 1, 1))
return (normalized * scale + offset), mean, var mean_tile = np.tile(mean_tile, (n, 1, h, w))
var_tile = np.reshape(var, (1, c, 1, 1))
var_tile = np.tile(var_tile, (n, 1, h, w))
normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
offset_tile = np.reshape(offset, (1, c, 1, 1))
offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
y = normalized * scale_tile + offset_tile
return y, mean, var
elif data_format == "NHWC":
x_square = x * x
x_square_sum = np.sum(x_square, (0, 1, 2))
x_sum = np.sum(x, axis=(0, 1, 2))
element_count = np.size(x) / int(np.shape(x)[-1])
mean = x_sum / element_count
var = x_square_sum / element_count - mean * mean
normalized = (x - mean) / np.sqrt(var + epsilon)
return (normalized * scale + offset), mean, var
else:
raise ValueError("Unknown data order.")
def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format): def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
...@@ -43,8 +63,13 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format): ...@@ -43,8 +63,13 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
# grad_x = # grad_x =
# 1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) - # 1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
# (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon)) # (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
if data_format != "NHWC":
raise ValueError("data_format must be NHWC, got %s." % data_format) # transfer from (N, C, H, W) to (N, H, W, C) to simplify computation
if data_format == "NCHW":
x = np.transpose(x, (0, 2, 3, 1))
grad_y = np.transpose(grad_y, (0, 2, 3, 1))
# raise ValueError("data_format must be NHWC, got %s." % data_format)
grad_x = scale * (grad_y - np.mean( grad_x = scale * (grad_y - np.mean(
grad_y, axis=(0, 1, 2)) - (x - mean) * np.mean( grad_y, axis=(0, 1, 2)) - (x - mean) * np.mean(
grad_y * (x - mean), axis=(0, 1, 2)) / grad_y * (x - mean), axis=(0, 1, 2)) /
...@@ -52,6 +77,12 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format): ...@@ -52,6 +77,12 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
grad_scale = np.sum(grad_y * (x - mean) / np.sqrt(var + epsilon), grad_scale = np.sum(grad_y * (x - mean) / np.sqrt(var + epsilon),
axis=(0, 1, 2)) axis=(0, 1, 2))
grad_offset = np.sum(grad_y, axis=(0, 1, 2)) grad_offset = np.sum(grad_y, axis=(0, 1, 2))
# transfer back to N, C, H, W
if data_format == "NCHW":
grad_x = np.transpose(grad_x, (0, 3, 1, 2))
x = np.transpose(x, (0, 3, 1, 2))
grad_y = np.transpose(grad_y, (0, 3, 1, 2))
return grad_x, grad_scale, grad_offset return grad_x, grad_scale, grad_offset
...@@ -65,61 +96,135 @@ def create_or_get_tensor(scope, var_name, var, place): ...@@ -65,61 +96,135 @@ def create_or_get_tensor(scope, var_name, var, place):
return tensor return tensor
def set_output_grad(scope, outputs, place): def set_output_grad(scope, outputs, place, feed_dict=None):
def __set_tensor__(name): def __set_tensor__(name, data=None):
out_tensor = scope.find_var(name).get_tensor() out_tensor = scope.find_var(name).get_tensor()
grad_tensor = scope.var(grad_var_name(name)).get_tensor() grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype() out_dtype = out_tensor.dtype()
if out_dtype == core.DataType.FP64: if data is None:
data = np.ones(out_tensor.shape(), dtype=np.float64) if out_dtype == core.DataType.FP64:
elif out_dtype == core.DataType.FP32: data = np.ones(out_tensor.shape(), dtype=np.float64)
data = np.ones(out_tensor.shape(), dtype=np.float32) elif out_dtype == core.DataType.FP32:
else: data = np.ones(out_tensor.shape(), dtype=np.float32)
raise ValueError("Not supported data type " + str(out_dtype)) else:
raise ValueError("Not supported data type " + str(out_dtype))
grad_tensor.set(data, place) grad_tensor.set(data, place)
for output in outputs: for output in outputs:
__set_tensor__(output) data = None
if output in feed_dict:
data = feed_dict[output]
__set_tensor__(output, data)
class TestBatchNormOp(OpTest): class TestBatchNormOp(OpTest):
def __assert_close(self, tensor, np_array, msg, atol=1e-4): def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def test_forward_backward(self): def test_python(self):
# attr
data_format = "NHWC" data_format = "NHWC"
epsilon = 0.00001 epsilon = 0.00001
momentum = 0.9 momentum = 0.9
channel_num = 2 # N, H, W, C: 2, 3, 4, 2
x_shape = [2, 3, 4, channel_num] n, h, w, c = 2, 3, 4, 2
scale_shape = [channel_num] x_shape = [n, h, w, c]
scale_shape = [c]
# input
x_val = np.random.random_sample(x_shape).astype(np.float32) x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32) bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32) mean = np.zeros(scale_shape).astype(np.float32)
variance = np.zeros(scale_shape).astype(np.float32) variance = np.ones(scale_shape).astype(np.float32)
# run forward # run forward
y_out, saved_mean, var_ref = _reference_training( y_out, saved_mean, var_ref = _reference_training(
x_val, scale_val, bias_val, epsilon, data_format) x_val, scale_val, bias_val, epsilon, "NHWC")
#
mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = var_ref * (1. - momentum) + momentum * variance
saved_variance = 1. / np.sqrt(var_ref + epsilon)
# running N, C, H, W case
# should produce the same results
x_shape2 = [n, c, h, w]
x_val2 = np.transpose(x_val, (0, 3, 1, 2))
y_out2, saved_mean2, var_ref2 = _reference_training(
x_val2, scale_val, bias_val, epsilon, "NCHW")
self.__assert_close(saved_mean, saved_mean2, "batch mean")
self.__assert_close(var_ref, var_ref2, "batch variance")
# transfer (N, C, H, W) back to (N, H, W, C)
y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
self.__assert_close(y_out, y_out2_trans, "batch variance")
print 'python: NHWC, NCHW, forward checking passed'
# test backward now
# NHWC
self.y_grad = np.random.random_sample(x_shape).astype(np.float32)
y_grad = self.y_grad
# y_grad = np.ones(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, "NHWC")
# run backward # NCHW
mean_out = saved_mean * (1 - momentum) y_grad2 = np.transpose(y_grad, (0, 3, 1, 2))
variance_out = var_ref * (1 - momentum) # y_grad2 = np.ones(x_shape2).astype(np.float32)
saved_variance = 1 / np.sqrt(var_ref + epsilon) x_grad_ref2, scale_grad_ref2, bias_grad_ref2 = _reference_grad(
x_val2, y_grad2, scale_val, saved_mean2, var_ref2, epsilon, "NCHW")
# for gradient test self.__assert_close(scale_grad_ref, scale_grad_ref2, "scale gradient")
y_grad = np.ones(x_shape).astype(np.float32) self.__assert_close(bias_grad_ref, bias_grad_ref2, "bias gradient")
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format) x_grad_transpose = np.transpose(x_grad_ref2, (0, 2, 3, 1))
self.__assert_close(x_grad_ref, x_grad_transpose, "x gradient")
print 'python: NHWC, NCHW, backward checking passed'
def test_forward_backward(self):
def test_with_place(place, tensor_format):
# attr
epsilon = 0.00001
momentum = 0.9
# N, H, W, C: 12, 3, 4, 2
n, h, w, c = 2, 3, 4, 2
if data_format == "NHWC":
x_shape = [n, h, w, c]
elif data_format == "NCHW":
x_shape = [n, c, h, w]
else:
raise ValueError("Unknown data type.")
scale_shape = [c]
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
# run forward
y_out, saved_mean, var_ref = _reference_training(
x_val, scale_val, bias_val, epsilon, data_format)
# update moving mean and variance
mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = var_ref * (1. - momentum) + momentum * variance
saved_variance = 1. / np.sqrt(var_ref + epsilon)
# for gradient test
# y_grad = np.ones(x_shape).astype(np.float32)
y_grad = np.zeros(x_shape).astype(np.float32)
y_grad[0, 0, 0, 0] = 1.
# y_grad = np.random.random_sample(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon,
data_format)
def test_with_place(place):
scope = core.Scope() scope = core.Scope()
# create input # create input
...@@ -157,7 +262,7 @@ class TestBatchNormOp(OpTest): ...@@ -157,7 +262,7 @@ class TestBatchNormOp(OpTest):
SavedVariance="saved_variance", SavedVariance="saved_variance",
# attrs # attrs
is_test=False, is_test=False,
tensor_format=data_format, tensor_format=tensor_format,
momentum=momentum, momentum=momentum,
epsilon=epsilon) epsilon=epsilon)
...@@ -170,20 +275,21 @@ class TestBatchNormOp(OpTest): ...@@ -170,20 +275,21 @@ class TestBatchNormOp(OpTest):
self.__assert_close(saved_variance_tensor, saved_variance, self.__assert_close(saved_variance_tensor, saved_variance,
"saved_variance") "saved_variance")
self.__assert_close(mean_out_tensor, mean_out, "mean_out") self.__assert_close(mean_out_tensor, mean_out, "mean_out")
# FIXME(qiao) figure out why with cuDNN variance_out have a higher error rate
if isinstance(place, core.GPUPlace): if isinstance(place, core.GPUPlace):
atol = 5e-2 atol = 5e-2
else: else:
atol = 1e-4 atol = 1e-4
self.__assert_close(variance_out_tensor, variance_out, self.__assert_close(variance_out_tensor, variance_out,
"variance_out", atol) "variance_out", atol)
print "op test forward passed: ", str(place), tensor_format
# run backward # run backward
batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set()) batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set())
set_output_grad( set_output_grad(
scope, scope,
["y_out", "mean", "variance", "saved_mean", "saved_variance"], ["y_out", "mean", "variance", "saved_mean", "saved_variance"],
place) place,
feed_dict={"y_out": y_grad})
batch_norm_op_grad.run(scope, ctx) batch_norm_op_grad.run(scope, ctx)
x_grad_tensor = create_or_get_tensor(scope, x_grad_tensor = create_or_get_tensor(scope,
...@@ -200,12 +306,14 @@ class TestBatchNormOp(OpTest): ...@@ -200,12 +306,14 @@ class TestBatchNormOp(OpTest):
self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad") self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad") self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad") self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
print "op test backward passed: ", str(place), tensor_format
places = [core.CPUPlace()] places = [core.CPUPlace()]
if core.is_compile_gpu() and core.op_support_gpu("batch_norm"): if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
places.append(core.GPUPlace(0)) places.append(core.GPUPlace(0))
for place in places: for place in places:
test_with_place(place) for data_format in ["NCHW", "NHWC"]:
test_with_place(place, data_format)
if __name__ == '__main__': if __name__ == '__main__':
......
import paddle.v2 as paddle
import paddle.v2.framework.layers as layers
import paddle.v2.framework.core as core
import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_program
from paddle.v2.framework.io import save_inference_model, load_inference_model
import paddle.v2.framework.executor as executor
import unittest
import numpy as np
class TestBook(unittest.TestCase):
def test_fit_line_inference_model(self):
MODEL_DIR = "./tmp/inference_model"
init_program = Program()
program = Program()
x = layers.data(
name='x',
shape=[2],
data_type='float32',
program=program,
init_program=init_program)
y = layers.data(
name='y',
shape=[1],
data_type='float32',
program=program,
init_program=init_program)
y_predict = layers.fc(input=x,
size=1,
act=None,
program=program,
init_program=init_program)
cost = layers.square_error_cost(
input=y_predict,
label=y,
program=program,
init_program=init_program)
avg_cost = layers.mean(
x=cost, program=program, init_program=init_program)
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
opts = sgd_optimizer.minimize(avg_cost)
place = core.CPUPlace()
exe = executor.Executor(place)
exe.run(init_program, feed={}, fetch_list=[])
for i in xrange(100):
x_data = np.array(
[[1, 1], [1, 2], [3, 4], [5, 2]]).astype("float32")
y_data = np.array([[-2], [-3], [-7], [-7]]).astype("float32")
tensor_x = core.LoDTensor()
tensor_x.set(x_data, place)
tensor_y = core.LoDTensor()
tensor_y.set(y_data, place)
exe.run(program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost])
save_inference_model(MODEL_DIR, ["x", "y"], [avg_cost], exe, program)
outs = exe.run(program,
feed={'x': tensor_x,
'y': tensor_y},
fetch_list=[avg_cost])
expected = np.array(outs[0])
reload(executor) # reload to build a new scope
exe = executor.Executor(place)
[infer_prog, feed_var_names, fetch_vars] = load_inference_model(
MODEL_DIR, exe)
outs = exe.run(
infer_prog,
feed={feed_var_names[0]: tensor_x,
feed_var_names[1]: tensor_y},
fetch_list=fetch_vars)
actual = np.array(outs[0])
self.assertEqual(feed_var_names, ["x", "y"])
self.assertEqual(len(fetch_vars), 1)
self.assertEqual(str(fetch_vars[0]), str(avg_cost))
self.assertEqual(expected, actual)
if __name__ == '__main__':
unittest.main()
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
import numpy
import paddle.v2 as paddle
exit(
0
) # FIXME(yuyang18): InferShape has been removed, this unittest should be changed until compile time is ready
BATCH_SIZE = 100
scope = core.Scope()
place = core.CPUPlace()
# if you want to test GPU training, you can use gpu place
# place = core.GPUPlace(0)
dev_ctx = core.DeviceContext.create(place)
init_net = core.Net.create()
forward_net = core.Net.create()
backward_net = None
optimize_net = core.Net.create()
def atomic_id():
id = 0
while True:
yield id
id += 1
uniq_id = atomic_id().next
def data_layer(name, dims):
var = scope.var(name)
tensor = var.get_tensor()
tensor.set_dims(dims) # 1 is batch size holder.
return name
def feed_data(name, data):
assert isinstance(data, numpy.ndarray)
tensor = scope.find_var(name).get_tensor()
tensor.set_dims(data.shape)
if data.dtype == numpy.dtype("int32"):
tensor.alloc_int(place)
elif data.dtype == numpy.dtype("float32"):
tensor.alloc_float(place)
else:
raise ValueError("data type not supported")
tensor.set(data, place)
def grad_var_name(var_name):
return var_name + "@GRAD"
def sgd_optimizer(net, param_name, learning_rate=0.005):
grad_name = grad_var_name(param_name)
optimize_op = Operator(
"sgd",
param=param_name,
grad=grad_name,
param_out=param_name,
learning_rate=learning_rate)
net.append_op(optimize_op)
# should use operator and add these to the init_network
def init_param(net, param_name, dims):
scope.var(param_name)
op = Operator(
"uniform_random", Out=param_name, dims=dims, min=-0.5, max=0.5, seed=10)
op.infer_shape(scope)
net.append_op(op)
# fc_layer
def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):
"""
The fully connected layer.
:param input: The name of input variable.
:type input: str
:param size: The size of fully connected layer.
:param act: The name of activation.
:param param: The attribute of learnable parameter which can be used to
modify initialization mean and std of the parameter.
:param bias: The attribute of bias. If set False, this layer does not have
a bias.
:param name: The name of this layer. If it is not set explictly, a name
will be generated automatically.
:return: The name of the output variable.
"""
if name is None:
name = "fc_%d" % uniq_id()
if not isinstance(name, str):
raise ValueError("The name of a layer should be a string.")
input_dims = scope.find_var(input).get_tensor().get_dims()
w_name = param or name + ".w"
init_param(net=init_net, param_name=w_name, dims=[input_dims[1], size])
sgd_optimizer(net=optimize_net, param_name=w_name, learning_rate=0.01)
pre_activation = name + ".mul.out"
scope.var(pre_activation)
mul_op = Operator("mul", X=input, Y=w_name, Out=pre_activation)
net.append_op(mul_op)
# create bias variable if needed
if bias:
bias_name = name + ".b"
init_param(net=init_net, param_name=bias_name, dims=[size])
sgd_optimizer(
net=optimize_net, param_name=bias_name, learning_rate=0.001)
bias_out = name + ".rowwise_add.out"
scope.var(bias_out)
rowwise_append_op = Operator(
"rowwise_add", X=pre_activation, b=bias_name, Out=bias_out)
net.append_op(rowwise_append_op)
pre_activation = bias_out
activation_op = Operator(act, X=pre_activation, Y=name)
net.append_op(activation_op)
scope.var(name)
net.infer_shape(scope)
return name
def cross_entropy_layer(net, input, label):
cost_name = "cross_entropy_%d" % uniq_id()
cross_entropy_op = Operator(
"cross_entropy", X=input, Label=label, Y=cost_name)
net.append_op(cross_entropy_op)
scope.var(cost_name)
net.infer_shape(scope)
return cost_name
def create_backward_net(forward_net):
net = core.Operator.backward(forward_net, set())
for input in net.inputs()["all"]:
var = scope.var(input)
var.get_tensor()
for output in net.outputs()["all"]:
var = scope.var(output)
var.get_tensor()
return net
def debug_print_op(op):
print("===============" + op.type() + "==============")
print("***inputs:***")
for input in op.inputs()["all"]:
print input, scope.find_var(input).get_tensor().get_dims()
print("\n***outputs:***")
for output in op.outputs()["all"]:
print output, scope.find_var(output).get_tensor().get_dims()
print("")
print("")
def set_cost(cost):
cost_shape = numpy.array(scope.find_var(cost).get_tensor()).shape
cost_grad = \
scope.find_var(grad_var_name(cost)).get_tensor()
cost_grad.set_dims(cost_shape)
cost_grad.alloc_float(place)
cost_grad.set(numpy.ones(cost_shape).astype("float32"), place)
def get_cost_mean(cost):
cost_data = numpy.array(scope.find_var(cost).get_tensor())
return cost_data.sum() / len(cost_data)
def error_rate(predict, label):
predict_var = numpy.array(scope.find_var(predict).get_tensor()).argmax(
axis=1)
label = numpy.array(scope.find_var(label).get_tensor())
error_num = numpy.sum(predict_var != label)
return error_num / float(len(label))
images = data_layer(name="pixel", dims=[BATCH_SIZE, 784])
labels = data_layer(name="label", dims=[BATCH_SIZE, 1])
fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid")
fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid")
predict = fc_layer(net=forward_net, input=fc2, size=10, act="softmax")
cost = cross_entropy_layer(net=forward_net, input=predict, label=labels)
init_net.complete_add_op(True)
forward_net.complete_add_op(True)
backward_net = create_backward_net(forward_net)
optimize_net.complete_add_op(True)
print(init_net)
print(forward_net)
print(backward_net)
print(optimize_net)
debug_print_op(forward_net)
debug_print_op(backward_net)
debug_print_op(optimize_net)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=BATCH_SIZE)
def test(cost_name):
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
cost = []
error = []
for data in test_reader():
image_data = numpy.array(map(lambda x: x[0], data)).astype("float32")
label_data = numpy.array(map(lambda x: x[1], data)).astype("int32")
label_data = numpy.expand_dims(label_data, axis=1)
feed_data(images, image_data)
feed_data(labels, label_data)
forward_net.infer_shape(scope)
forward_net.run(scope, dev_ctx)
cost.append(get_cost_mean(cost_name))
error.append(error_rate(predict, "label"))
print("cost=" + str(sum(cost) / float(len(cost))) + " error_rate=" + str(
sum(error) / float(len(error))))
PASS_NUM = 1
init_net.run(scope, dev_ctx)
for pass_id in range(PASS_NUM):
batch_id = 0
for data in train_reader():
image_data = numpy.array(map(lambda x: x[0], data)).astype("float32")
label_data = numpy.array(map(lambda x: x[1], data)).astype("int32")
label_data = numpy.expand_dims(label_data, axis=1)
feed_data(images, image_data)
feed_data(labels, label_data)
forward_net.infer_shape(scope)
forward_net.run(scope, dev_ctx)
set_cost(cost)
backward_net.infer_shape(scope)
backward_net.run(scope, dev_ctx)
optimize_net.run(scope, dev_ctx)
if batch_id % 100 == 0:
print("pass[" + str(pass_id) + "] batch_id[" + str(batch_id) + "]")
test(cost)
batch_id = batch_id + 1
import unittest import unittest
from paddle.v2.framework.framework import Variable, g_program from paddle.v2.framework.framework import Variable, Program, g_program
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
...@@ -21,7 +21,8 @@ class TestOperator(unittest.TestCase): ...@@ -21,7 +21,8 @@ class TestOperator(unittest.TestCase):
"Operator \"no_such_op\" has not been registered.") "Operator \"no_such_op\" has not been registered.")
def test_op_desc_creation(self): def test_op_desc_creation(self):
block = g_program.current_block() program = Program()
block = program.current_block()
mul_x = block.create_var(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = block.create_var(
...@@ -50,10 +51,12 @@ class TestOperator(unittest.TestCase):
self.assertEqual(mul_op.has_attr("y_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("y_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("y_num_col_dims"), 1)
self.assertEqual(mul_op.idx, 0)
self.assertEqual(mul_out.op, mul_op)
def test_mult_input(self):
program = Program()
block = program.current_block()
sum_x1 = block.create_var(
dtype="int", shape=[3, 4], lod_level=0, name="sum.x1")
sum_x2 = block.create_var(
...@@ -71,6 +74,7 @@ class TestOperator(unittest.TestCase):
self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"])
self.assertEqual(sum_op.output_names, ["Out"])
self.assertEqual(sum_op.output("Out"), ["sum.out"])
self.assertEqual(sum_op.idx, 0)
self.assertEqual(sum_out.op, sum_op)
...
...@@ -27,6 +27,32 @@ class TestOptimizer(unittest.TestCase):
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "sgd")
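# new case: with a global_step variable, minimize() appends an increment op after the sgd op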
def test_sgd_optimizer_with_global_step(self):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
global_step = block.create_var(
dtype="float32", shape=[1], lod_level=0, name="step")
sgd_optimizer = optimizer.SGDOptimizer(
learning_rate=0.01, global_step=global_step)
opts = sgd_optimizer.minimize(mul_out)
self.assertEqual(len(opts), 2)
sgd_op = opts[0]
self.assertEqual(sgd_op.type, "sgd")
increment_op = opts[1]
self.assertEqual(increment_op.type, "increment")
class TestMomentumOptimizer(unittest.TestCase):
class MockMomentum(optimizer.MomentumOptimizer):
...
...@@ -99,6 +99,8 @@ class TestProgram(unittest.TestCase):
outputs={"Out": add_out},
attrs={"x_num_col_dims": 1})
self.assertEqual(mul_op.idx, 0)
self.assertEqual(add_op.idx, 1)
param_to_grad = prog.append_backward(add_out, set())
def grad_name(name):
...
...@@ -5,9 +5,11 @@ import paddle.v2.framework.optimizer as optimizer
from paddle.v2.framework.framework import Program, g_program
from paddle.v2.framework.executor import Executor
from paddle.v2.framework.regularizer import L2DecayRegularizer
import numpy as np
BATCH_SIZE = 128
init_program = Program()
program = Program()
image = layers.data(
...@@ -17,22 +19,35 @@ image = layers.data(
program=program,
init_program=init_program)
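# shared parameter attributes: uniform random init in [-1, 1] plus L2 weight decay scaled by the batch size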
param_attr = {
'name': None,
'init_attr': {
'type': 'uniform_random',
'min': -1.0,
'max': 1.0
},
'regularization': L2DecayRegularizer(0.0005 * BATCH_SIZE)
}
hidden1 = layers.fc(input=image,
size=128,
act='relu',
program=program,
init_program=init_program,
param_attr=param_attr)
hidden2 = layers.fc(input=hidden1,
size=64,
act='relu',
program=program,
init_program=init_program,
param_attr=param_attr)
predict = layers.fc(input=hidden2,
size=10,
act='softmax',
program=program,
init_program=init_program,
param_attr=param_attr)
label = layers.data(
name='y',
...@@ -48,8 +63,6 @@ avg_cost = layers.mean(x=cost, program=program, init_program=init_program)
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
opts = sgd_optimizer.minimize(avg_cost)
BATCH_SIZE = 128
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
...
...@@ -39,5 +39,39 @@ class TestL2DecayRegularizer(unittest.TestCase):
self.assertEqual(block.ops[-2].type, 'scale')
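# new case: L1 weight decay appends sign, scale and elementwise_add ops to the block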
class TestL1DecayRegularizer(unittest.TestCase):
def test_l1decay_regularizer(self):
program = framework.Program()
block = program.global_block()
mul_x = block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="mul.x",
regularizer=regularizer.L1DecayRegularizer(0.5))
self.assertTrue(mul_x.regularizer is not None)
self.assertTrue(
isinstance(mul_x.regularizer, regularizer.L1DecayRegularizer))
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
block.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
params_grads = append_backward_ops(mul_out)
self.assertEqual(len(params_grads), 1)
count_ops = len(block.ops)
params_grads = optimizer.append_regularization_ops(params_grads)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(block.ops), count_ops + 3)
self.assertEqual(block.ops[-1].type, 'elementwise_add')
self.assertEqual(block.ops[-2].type, 'scale')
self.assertEqual(block.ops[-3].type, 'sign')
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
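# sign op test: the output should equal np.sign of the input; the gradient w.r.t. X is also checked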
class TestSignOp(OpTest):
def setUp(self):
self.op_type = "sign"
self.inputs = {
'X': np.random.uniform(-10, 10, (10, 10)).astype("float32")
}
self.outputs = {'Out': np.sign(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
if __name__ == "__main__":
unittest.main()