Commit b430c238 authored by T tensor-tang

Merge remote-tracking branch 'upstream/develop' into mklso

......@@ -37,8 +37,8 @@ before_install:
- if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
# Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
# protobuf version.
- pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker
- pip install rarfile nltk==3.2.2 scipy==0.19.0 recordio matplotlib Pillow
- pip install -r $TRAVIS_BUILD_DIR/python/requirements.txt
- pip install wheel sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit LinkChecker
- curl https://glide.sh/get | bash
- eval "$(GIMME_GO_VERSION=1.8.3 gimme)"
- go get -u github.com/alecthomas/gometalinter
......
## Auto Gradient Checker Design
## Background:
- The forward computation of an operator is easy to check, because it has a clear definition. **But** backpropagation is a notoriously difficult algorithm to debug and get right:
- 1. you have to derive the correct backpropagation formula from the forward computation.
- 2. you have to implement it correctly in C++.
- 3. it is difficult to prepare test data.
- Auto gradient checking computes a numeric gradient using only the forward operator and uses it as a reference for the backward operator's result. It has several advantages:
- 1. the numeric gradient checker only needs the forward operator.
- 2. the user only needs to prepare the input data for the forward operator.
## Mathematical Theory
The following two documents from Stanford give a detailed explanation of how to compute the numeric gradient and why it is useful.
- [Gradient checking and advanced optimization(en)](http://deeplearning.stanford.edu/wiki/index.php/Gradient_checking_and_advanced_optimization)
- [Gradient checking and advanced optimization(cn)](http://ufldl.stanford.edu/wiki/index.php/%E6%A2%AF%E5%BA%A6%E6%A3%80%E9%AA%8C%E4%B8%8E%E9%AB%98%E7%BA%A7%E4%BC%98%E5%8C%96)
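In short, for a scalar output $y = f(x)$, the checker uses the centered-difference approximation below; this is exactly the `(y_pos - y_neg) / delta / 2` computation in the implementation that follows.

```latex
\frac{\partial y}{\partial x_i}
  \approx \frac{f(x + \delta e_i) - f(x - \delta e_i)}{2\delta},
  \qquad e_i = \text{the } i\text{-th unit vector}
```

The error of this approximation is $O(\delta^2)$, which is why it is preferred over the one-sided difference $(f(x+\delta)-f(x))/\delta$, whose error is $O(\delta)$.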
## Numeric Gradient Implementation
### Python Interface
```python
def get_numeric_gradient(op,
                         input_values,
                         output_name,
                         input_to_check,
                         delta=0.005,
                         local_scope=None):
    """
    Get the numeric gradient of an operator's input.

    :param op: C++ operator instance; could be a network.
    :param input_values: the input variables. Should be a dictionary whose
        keys are variable names and whose values are numpy arrays.
    :param output_name: the name of the final output variable.
    :param input_to_check: the input variable to compute the gradient for.
    :param delta: the perturbation value for the numeric gradient method.
        The smaller delta is, the more accurate the result will be, but if
        delta is too small, numerical stability problems can occur.
    :param local_scope: the local scope used by get_numeric_gradient.
    :return: the gradient array in numpy format.
    """
```
### Explanation:
- Why is `output_name` needed?
  - An operator may have multiple outputs, and an independent gradient can be computed from each output, so the user must specify which output to use.
- Why is `input_to_check` needed?
  - An operator may have multiple inputs. A gradient op can compute the gradients of all these inputs at the same time, but the numeric gradient has to compute them one by one. `get_numeric_gradient` is therefore designed to compute the gradient of a single input; to check multiple inputs, call `get_numeric_gradient` multiple times.
### Core Algorithm Implementation
```python
# We only compute the gradient of one element at a time,
# so a for loop computes the gradient of every element.
for i in xrange(tensor_size):
    # Get one input element through its index i.
    origin = tensor_to_check.get_float_element(i)
    # Add delta to it, run the op, and take the sum of the result tensor.
    x_pos = origin + delta
    tensor_to_check.set_float_element(i, x_pos)
    y_pos = get_output()
    # Subtract delta from it, run the op, and take the sum of the result tensor.
    x_neg = origin - delta
    tensor_to_check.set_float_element(i, x_neg)
    y_neg = get_output()
    # Restore the old value.
    tensor_to_check.set_float_element(i, origin)
    # Compute the gradient of this element and store it in a numpy array.
    gradient_flat[i] = (y_pos - y_neg) / delta / 2
# Reshape the gradient result to the shape of the source tensor.
return gradient_flat.reshape(tensor_to_check.get_dims())
```
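The same per-element loop can be exercised outside the framework. Below is a minimal, self-contained NumPy sketch of the idea; the helper name `numeric_gradient` and the toy mean operator are illustrative assumptions, not PaddlePaddle API.

```python
import numpy

def numeric_gradient(forward, x, delta=0.005):
    """Centered-difference gradient of a scalar-valued `forward` at `x`."""
    x = x.astype(numpy.float64)   # work on a copy of the input
    flat = x.ravel()              # 1-D view into the copy
    grad = numpy.zeros_like(flat)
    for i in range(flat.size):
        origin = flat[i]
        flat[i] = origin + delta  # perturb one element upward
        y_pos = forward(x)
        flat[i] = origin - delta  # perturb it downward
        y_neg = forward(x)
        flat[i] = origin          # restore the old value
        grad[i] = (y_pos - y_neg) / delta / 2
    return grad.reshape(x.shape)

# Toy check: the gradient of mean(x) w.r.t. every element is 1 / x.size.
x = numpy.random.rand(2, 3)
assert numpy.allclose(numeric_gradient(lambda t: t.mean(), x),
                      numpy.ones_like(x) / x.size)
```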
## Auto Gradient Checker Framework
Each operator kernel has three kinds of gradients:
- 1. the numeric gradient
- 2. the CPU operator gradient
- 3. the GPU operator gradient (if supported)

The numeric gradient relies only on the forward operator, so we use it as the reference value. The check proceeds in three steps:
- 1. calculate the numeric gradient.
- 2. calculate the CPU kernel gradient with the backward operator and compare it with the numeric gradient.
- 3. calculate the GPU kernel gradient with the backward operator and compare it with the numeric gradient (if GPU is supported).
### Python Interface
```python
def check_grad(self,
               forward_op,
               input_vars,
               inputs_to_check,
               output_name,
               no_grad_set=None,
               only_cpu=False,
               max_relative_error=0.005):
    """
    :param forward_op: used to create the backward_op.
    :param input_vars: numpy values of the input variables. The following
        computation will use these variables.
    :param inputs_to_check: the names of the input variables whose gradients
        should be checked.
    :param output_name: the output name used to compute gradients.
    :param max_relative_error: the relative tolerance parameter.
    :param no_grad_set: used when creating the backward ops.
    :param only_cpu: only compute and check the gradient on the CPU kernel.
    :return:
    """
```
### How do we check whether two numpy arrays are close enough?
If `abs_numeric_grad` is nearly zero, we use the absolute error for `numeric_grad` instead of the relative error:
```python
numeric_grad = ...
operator_grad = numpy.array(scope.find_var(grad_var_name(name)).get_tensor())

abs_numeric_grad = numpy.abs(numeric_grad)
# If abs_numeric_grad is nearly zero, use the absolute error for numeric_grad
# instead of the relative error.
abs_numeric_grad[abs_numeric_grad < 1e-3] = 1

diff_mat = numpy.abs(numeric_grad - operator_grad) / abs_numeric_grad
max_diff = numpy.max(diff_mat)
```
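This metric is easy to check in isolation. Below is a self-contained NumPy version with made-up gradient values (a sketch, not the framework's test code):

```python
import numpy

def max_relative_error(numeric_grad, operator_grad):
    abs_numeric = numpy.abs(numeric_grad)
    # Near-zero reference entries switch the metric to absolute error.
    abs_numeric[abs_numeric < 1e-3] = 1.0
    return numpy.max(numpy.abs(numeric_grad - operator_grad) / abs_numeric)

numeric = numpy.array([0.5, -0.2, 1e-6])         # reference (numeric) gradient
operator = numpy.array([0.5005, -0.2002, 3e-6])  # backward-op gradient
assert max_relative_error(numeric, operator) < 0.005
```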
#### Notes:
1. The input data for the auto gradient checker should be reasonable, to avoid numerical stability problems.
#### Refs:
- [Gradient checking and advanced optimization(en)](http://deeplearning.stanford.edu/wiki/index.php/Gradient_checking_and_advanced_optimization)
- [Gradient checking and advanced optimization(cn)](http://ufldl.stanford.edu/wiki/index.php/%E6%A2%AF%E5%BA%A6%E6%A3%80%E9%AA%8C%E4%B8%8E%E9%AB%98%E7%BA%A7%E4%BC%98%E5%8C%96)
......@@ -17,12 +17,10 @@ def main():
# network config
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x,
param_attr=paddle.attr.Param(
name='w', learning_rate=1e-3),
param_attr=paddle.attr.Param(name='w'),
size=1,
act=paddle.activation.Linear(),
bias_attr=paddle.attr.Param(
name='b', learning_rate=1e-3))
bias_attr=paddle.attr.Param(name='b'))
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.mse_cost(input=y_predict, label=y)
......
......@@ -41,7 +41,7 @@ ParameterUpdater *ParameterUpdater::createNewRemoteUpdater(
config->m->getConfig(), pserverSpec, useEtcd));
return updater;
#else
throw UnsupportError();
throw UnsupportError("not compiled with WITH_GOLANG");
#endif
}
......
......@@ -36,8 +36,8 @@ py_proto_compile(framework_py_proto SRCS attribute.proto op_proto.proto op_desc.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
add_custom_command(TARGET framework_py_proto POST_BUILD
COMMAND ${CMAKE_COMMAND} -E make_directory ${PROJ_ROOT}/python/paddle/v2/framework/proto
COMMAND cp *.py ${PROJ_ROOT}/python/paddle/v2/framework/proto/
COMMAND ${CMAKE_COMMAND} -E make_directory ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/proto
COMMAND cp *.py ${PADDLE_SOURCE_DIR}/python/paddle/v2/framework/proto/
COMMENT "Copy generated python proto into directory paddle/v2/framework/proto."
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
......@@ -48,9 +48,12 @@ if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
SRCS pybind.cc
DEPS pybind python backward
fc_op
sgd_op
add_op
mul_op
rowwise_add_op
sigmoid_op
softmax_op
mean_op
cross_entropy_op
recurrent_op
......
......@@ -30,6 +30,8 @@ using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, OperatorBase)
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
};
......
......@@ -20,7 +20,6 @@ limitations under the License. */
#include "paddle/framework/dim.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/variant.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace framework {
......
......@@ -19,45 +19,44 @@ permissions and limitations under the License. */
namespace paddle {
namespace framework {
class OpRegistry;
using VarIndexMap = std::unordered_map<std::string, int>;
typedef std::vector<int> Ints;
enum class OpArgType { IN, OUT };
static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) {
std::string key = type == OpArgType::IN ? "input_format" : "output_format";
return op->attrs_.count(key)
? &boost::get<std::vector<int>>(op->attrs_.at(key))
: nullptr;
const Ints* AttrFormat(const AttributeMap& attrs, const std::string& key) {
return (attrs.count(key) > 0) ? &boost::get<Ints>(attrs.at(key)) : nullptr;
}
static const std::vector<int>* GetOpFormat(const OperatorBase* op,
const OpArgType& type) {
std::string key = type == OpArgType::IN ? "input_format" : "output_format";
return op->attrs_.count(key)
? &boost::get<std::vector<int>>(op->attrs_.at(key))
: nullptr;
Ints* AttrFormat(AttributeMap& attrs, const std::string& key) {
return (attrs.count(key) > 0) ? &boost::get<Ints>(attrs.at(key)) : nullptr;
}
static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
const OpArgType& src_type, const OpArgType& dst_type,
static void TransOpArg(const OperatorBase* src_op,
std::vector<std::string>& grad_inputs,
std::vector<std::string>& grad_outputs,
AttributeMap& grad_attrs,
std::unordered_map<std::string, int>& grad_idxs,
const std::string& src_type, const std::string& dst_type,
int& idx, bool is_grad) {
const std::vector<std::string>& src_inout =
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
const std::vector<int>* src_format = GetOpFormat(src_op, src_type);
(src_type == "input_format") ? src_op->inputs_ : src_op->outputs_;
const std::vector<int>* src_format = AttrFormat(src_op->Attrs(), src_type);
std::vector<std::string>& dst_inout =
dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
(dst_type == "input_format") ? grad_inputs : grad_outputs;
std::vector<int>* dst_format = AttrFormat(grad_attrs, dst_type);
const OpProto& proto = OpRegistry::protos().at(src_op->type_);
const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
(src_type == "input_format") ? proto.inputs() : proto.outputs();
for (const auto& arg : src_arg_list) {
std::string src_name = arg.name();
std::string dst_name = is_grad ? src_name + kGradVarSuffix : src_name;
(*dst_op->in_out_idxs_)[dst_name] = idx++;
grad_idxs[dst_name] = idx++;
int src_arg_idx = src_op->in_out_idxs_->at(src_name);
int src_begin =
src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx);
......@@ -76,26 +75,42 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
}
OperatorBase* BuildGradOp(const OperatorBase* op) {
std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type;
grad_op->attrs_ = op->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
if (GetOpFormat(op, OpArgType::IN) != nullptr) {
grad_op->attrs_["output_format"] = std::vector<int>({0});
const std::string& grad_op_type = OpRegistry::grad_ops().at(op->Type());
AttributeMap grad_attrs(op->Attrs());
grad_attrs.erase("input_format");
grad_attrs.erase("output_format");
if (op->Attrs().count("input_format") > 0) {
grad_attrs["output_format"] = std::vector<int>({0});
}
if (GetOpFormat(op, OpArgType::IN) != nullptr ||
GetOpFormat(op, OpArgType::OUT) != nullptr) {
grad_op->attrs_["input_format"] = std::vector<int>({0});
if (op->Attrs().count("input_format") > 0 ||
op->Attrs().count("output_format") > 0) {
grad_attrs["input_format"] = std::vector<int>({0});
}
grad_op->in_out_idxs_.reset(new VarIndexMap());
std::vector<std::string> grad_inputs, grad_outputs;
using VarIndexMap = std::unordered_map<std::string, int>;
VarIndexMap* grad_idxs = new VarIndexMap;
int in_idx = 0;
int out_idx = 0;
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, in_idx, false); // I
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, false); // G
TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, true); // OG
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true); // IG
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"input_format", "input_format", in_idx, false); // I
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"output_format", "input_format", in_idx, false); // G
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"output_format", "input_format", in_idx, true); // OG
TransOpArg(op, grad_inputs, grad_outputs, grad_attrs, *grad_idxs,
"input_format", "output_format", out_idx, true); // IG
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
grad_op->type_ = grad_op_type;
grad_op->inputs_ = grad_inputs;
grad_op->outputs_ = grad_outputs;
grad_op->attrs_ = grad_attrs;
grad_op->in_out_idxs_.reset(grad_idxs);
return grad_op;
}
......
......@@ -10,6 +10,8 @@ namespace framework {
class NOP : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(NOP, OperatorBase)
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
......
......@@ -69,18 +69,18 @@ class OpProtoAndCheckerMaker {
VariableBuilder AddInput(const std::string& name,
const std::string& comment) {
auto input = proto_->mutable_inputs()->Add();
*input->mutable_name() = name;
*input->mutable_comment() = comment;
VarProto* input = proto_->add_inputs();
input->set_name(name);
input->set_comment(comment);
return VariableBuilder{input, [=] { this->SetHasMultipleInput(); },
nullptr};
}
VariableBuilder AddOutput(const std::string& name,
const std::string& comment) {
auto output = proto_->mutable_outputs()->Add();
*output->mutable_name() = name;
*output->mutable_comment() = comment;
VarProto* output = proto_->add_outputs();
output->set_name(name);
output->set_comment(comment);
return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); },
[=] { this->SetHasTemporaryOutput(); }};
}
......@@ -89,17 +89,15 @@ class OpProtoAndCheckerMaker {
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment,
bool generated = false) {
auto attr = proto_->mutable_attrs()->Add();
*attr->mutable_name() = name;
*attr->mutable_comment() = comment;
AttrProto* attr = proto_->add_attrs();
attr->set_name(name);
attr->set_comment(comment);
attr->set_generated(generated);
attr->set_type(AttrTypeID<T>());
return op_checker_->AddAttrChecker<T>(name);
}
void AddComment(const std::string& comment) {
*(proto_->mutable_comment()) = comment;
}
void AddComment(const std::string& comment) { proto_->set_comment(comment); }
private:
void SetHasMultiple(const std::string& in_out, bool* flag) {
......@@ -187,7 +185,7 @@ class OpRegistry {
OpProto& op_proto = protos()[op_type];
auto maker = ProtoMakerType(&op_proto, &op_checker);
maker.Validate();
*op_proto.mutable_type() = op_type;
op_proto.set_type(op_type);
PADDLE_ENFORCE(
op_proto.IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
......@@ -307,22 +305,45 @@ class OpRegistry {
}
};
class Registrar {
public:
// In our design, various kinds of classes, e.g., operators and kernels, each
// have a corresponding registry and registrar. Registration happens in the
// constructor of a global registrar variable which, however, is not used by
// the code that links against the framework package, so the linker would
// strip it from the generated binary. To avoid such removal, we add a Touch
// method to all registrar classes and make the USE_OP macros call it. As
// long as the calling code uses USE_OP, the global registrar variable won't
// be removed by the linker.
void Touch() {}
};
template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper {
class OpRegistrar : public Registrar {
public:
explicit OpRegisterHelper(const char* op_type) {
explicit OpRegistrar(const char* op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
}
};
template <typename GradOpType>
class GradOpRegisterHelper {
class GradOpRegistrar : public Registrar {
public:
GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
GradOpRegistrar(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
}
};
template <typename PlaceType, typename KernelType>
class OpKernelRegistrar : public Registrar {
public:
explicit OpKernelRegistrar(const char* op_type) {
OperatorWithKernel::OpKernelKey key;
key.place_ = PlaceType();
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType);
}
};
/**
* check if MACRO is used in GLOBAL NAMESPACE.
*/
......@@ -333,97 +354,121 @@ class GradOpRegisterHelper {
msg)
/**
* Macro to Register Operator.
* Macro to register Operator.
*/
#define REGISTER_OP(__op_type, __op_class, __op_maker_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type, \
"REGISTER_OP must be in global namespace"); \
static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \
__op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; }
#define REGISTER_OP(op_type, op_class, op_maker_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
__op_registrar_##op_type##__(#op_type); \
int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \
return 0; \
}
/**
* Macro to Register Gradient Operator.
* Macro to register Gradient Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type##__grad_op_type, \
"REGISTER_GRADIENT_OP must be in global namespace"); \
static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \
__op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \
#__grad_op_type); \
int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
return 0; \
#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##op_type##_##grad_op_type, \
"REGISTER_GRADIENT_OP must be called in global namespace"); \
static ::paddle::framework::GradOpRegistrar<grad_op_class> \
__op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
#grad_op_type); \
int TouchOpGradientRegistrar_##op_type() { \
__op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \
return 0; \
}
/**
* Macro to Forbid user register Gradient Operator.
* Macro to register OperatorKernel.
*/
#define NO_GRADIENT(__op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##__op_type##__op_type##_grad, \
"NO_GRADIENT must be in global namespace")
#define REGISTER_OP_KERNEL(op_type, DEVICE_TYPE, place_class, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be called in global namespace"); \
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__(#op_type); \
int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() { \
__op_kernel_registrar_##op_type##_##DEVICE_TYPE##__.Touch(); \
return 0; \
}
/**
* Macro to Register OperatorKernel.
* Macro to Forbid user register Gradient Operator.
*/
#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be in global namespace"); \
struct __op_kernel_register__##type##__##DEVICE_TYPE##__ { \
__op_kernel_register__##type##__##DEVICE_TYPE##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new __VA_ARGS__()); \
} \
}; \
static __op_kernel_register__##type##__##DEVICE_TYPE##__ \
__reg_kernel_##type##__##DEVICE_TYPE##__; \
int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }
// (type, KernelType)
#define REGISTER_OP_GPU_KERNEL(type, ...) \
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
// (type, KernelType)
#define REGISTER_OP_CPU_KERNEL(type, ...) \
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
#define NO_GRADIENT(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##op_type##_##op_type##_grad, \
"NO_GRADIENT must be called in global namespace")
#define REGISTER_OP_GPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
/**
* Macro to mark what Operator and Kernel we will use and tell the compiler to
* link them into target.
*/
#define USE_OP_WITHOUT_KERNEL(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_without_kernel_##op_type, \
"USE_OP_WITHOUT_KERNEL must be in global namespace"); \
extern int __op_register_##op_type##_handle__(); \
static int __use_op_ptr_##op_type##_without_kernel__ \
__attribute__((unused)) = __op_register_##op_type##_handle__()
#define USE_OP_KERNEL(op_type, DEVICE_TYPE) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"USE_OP_KERNEL must be in global namespace"); \
extern int __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__(); \
static int __use_op_ptr_##op_type##_##DEVICE_TYPE##_kernel__ \
__attribute__((unused)) = \
__op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__()
// use Operator with only cpu kernel.
#define USE_OP_CPU(op_type) \
USE_OP_WITHOUT_KERNEL(op_type); \
USE_OP_KERNEL(op_type, CPU)
#define USE_OP_ITSELF(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_itself_##op_type, \
"USE_OP_ITSELF must be called in global namespace"); \
extern int TouchOpRegistrar_##op_type(); \
static int use_op_itself_##op_type##_ __attribute__((unused)) = \
TouchOpRegistrar_##op_type()
// TODO(fengjiayi): Most ops' gradient ops have not been completed, so we use
// `NO_GRAD` to disable the macro USE_OP_GRADIENT(op_type). Otherwise the code
// can't be compiled. `NO_GRAD` should be removed after all gradient ops are
// completed.
#define NO_GRAD
#ifndef NO_GRAD
#define USE_OP_GRADIENT(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_gradient_##op_type, \
"USE_OP_GRADIENT must be called in global namespace"); \
extern int TouchOpGradientRegistrar_##op_type(); \
static int use_op_gradient_##op_type##_ __attribute__((unused)) = \
TouchOpGradientRegistrar_##op_type()
#else
#define USE_OP_GRADIENT(op_type)
#endif
#define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"USE_OP_DEVICE_KERNEL must be in global namespace"); \
extern int TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE(); \
static int use_op_kernel_##op_type##_##DEVICE_TYPE##_ \
__attribute__((unused)) = \
TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE()
// TODO(fengjiayi): The following macros seems ugly, do we have better method?
#ifdef PADDLE_ONLY_CPU
#define USE_OP(op_type) USE_OP_CPU(op_type)
#define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU)
#else
#define USE_OP(op_type) \
USE_OP_CPU(op_type); \
USE_OP_KERNEL(op_type, GPU)
#define USE_OP_KERNEL(op_type) \
USE_OP_DEVICE_KERNEL(op_type, CPU); \
USE_OP_DEVICE_KERNEL(op_type, GPU)
#endif
#define USE_NO_GRAD_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type)
#define USE_CPU_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU); \
USE_OP_GRADIENT(op_type)
#define USE_OP(op_type) \
USE_NO_GRAD_OP(op_type); \
USE_OP_GRADIENT(op_type)
} // namespace framework
} // namespace paddle
......@@ -7,6 +7,8 @@ namespace paddle {
namespace framework {
class CosineOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(CosineOp, OperatorBase)
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {}
......@@ -27,6 +29,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(MyTestOp, OperatorBase)
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
......
......@@ -63,6 +63,17 @@ class ExecutionContext;
*/
class OperatorBase {
public:
OperatorBase() {} // TODO(yi): This constructor is to be removed.
OperatorBase(const std::string& type, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: type_(type),
inputs_(inputs),
outputs_(outputs),
attrs_(attrs),
in_out_idxs_(in_out_idxs) {}
virtual ~OperatorBase() {}
template <typename T>
......@@ -95,16 +106,24 @@ class OperatorBase {
//! Get a input with argument's name described in `op_proto`
const std::string& Input(const std::string& name) const;
//! Get a input which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Inputs(const std::string& name) const;
//! Get a output with argument's name described in `op_proto`
const std::string& Output(const std::string& name) const;
//! Get an output which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Outputs(const std::string& name) const;
const std::string Type() const { return type_; }
const std::vector<std::string> Inputs() const { return inputs_; }
const std::vector<std::string> Outputs() const { return outputs_; }
const AttributeMap& Attrs() const { return attrs_; }
const std::unordered_map<std::string, int>* InOutIdx() const {
return in_out_idxs_.get();
}
public:
std::string type_;
// NOTE: in case of OpGrad, inputs_ contains:
......@@ -281,6 +300,14 @@ class OpKernel {
class OperatorWithKernel : public OperatorBase {
public:
OperatorWithKernel() {} // TODO(yi): This constructor is to be removed.
OperatorWithKernel(const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const AttributeMap& attrs,
std::unordered_map<std::string, int>* in_out_idxs)
: OperatorBase(type, inputs, outputs, attrs, in_out_idxs) {}
struct OpKernelKey {
platform::Place place_;
......@@ -330,5 +357,15 @@ class OperatorWithKernel : public OperatorBase {
virtual void InferShape(const InferShapeContext& ctx) const = 0;
};
#define DEFINE_OPERATOR_CTOR(Class, ParentClass) \
public: \
Class() { /* TODO(yi): This constructor is to be removed. */ \
} \
Class(const std::string& type, const std::vector<std::string>& inputs, \
const std::vector<std::string>& outputs, \
const ::paddle::framework::AttributeMap& attrs, \
std::unordered_map<std::string, int>* in_out_idxs) \
: ParentClass(type, inputs, outputs, attrs, in_out_idxs) {}
} // namespace framework
} // namespace paddle
......@@ -23,6 +23,8 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(OpWithoutKernelTest, OperatorBase)
void Init() override { x = 1; }
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
......@@ -97,6 +99,8 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
static int cpu_kernel_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel {
public:
DEFINE_OPERATOR_CTOR(OpWithKernelTest, OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {}
};
......@@ -116,6 +120,8 @@ class CPUKernelTest : public OpKernel {
// multiple inputs test
class OperatorMultiInputsTest : public OperatorBase {
public:
DEFINE_OPERATOR_CTOR(OperatorMultiInputsTest, OperatorBase)
void Init() override { x = 1; }
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
......
......@@ -30,16 +30,15 @@ limitations under the License. */
namespace py = pybind11;
USE_OP(add_two);
USE_OP_CPU(onehot_cross_entropy);
USE_OP_WITHOUT_KERNEL(fc);
USE_OP(sgd);
USE_CPU_OP(onehot_cross_entropy);
USE_NO_GRAD_OP(sgd);
USE_OP(mul);
USE_OP(mean);
USE_OP(sigmoid);
USE_OP(softmax);
USE_OP(rowwise_add);
USE_OP(fill_zeros_like);
USE_OP_WITHOUT_KERNEL(recurrent_op);
USE_OP_ITSELF(recurrent_op);
USE_OP(gaussian_random);
USE_OP(uniform_random);
......
......@@ -38,10 +38,11 @@ if(WITH_GPU)
add_simple_unittest(RowConvOpTest)
add_simple_unittest(BlockExpandOpTest)
add_simple_unittest(CropOpTest)
add_simple_unittest(DepthwiseConvOpTest)
endif()
add_simple_unittest(ConvOpTest)
add_simple_unittest(Im2ColTest)
add_simple_unittest(GemmConvOpTest)
endif()
add_style_check_target(paddle_function ${h_files})
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "Function.h"
#include "FunctionTest.h"
namespace paddle {
enum TestType {
kForwardTest = 0,
kBackwardInputTest = 1,
kBackwardFilterTest = 2,
};
template <DeviceType DType1, DeviceType DType2>
class ConvolutionTest {
public:
ConvolutionTest(const std::string& conv1,
const std::string& conv2,
TestType type,
bool useGroups = true,
std::string algo = "auto") {
for (size_t batchSize : {1, 32}) {
for (size_t inputSize : {7, 14, 54}) {
for (size_t filterSize : {1, 3, 5}) {
for (size_t inputChannels : {3, 64}) {
for (size_t outputChannels : {3, 64}) {
if (inputChannels > outputChannels) break;
size_t groups;
if (!useGroups) {
groups = 1;
} else {
if (outputChannels % inputChannels != 0) continue;
groups = inputChannels;
}
for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) {
if (padding >= filterSize) break;
size_t outputSize =
(inputSize - filterSize + 2 * padding + stride) / stride;
VLOG(3) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputSize
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", groups)
.set("algo", algo));
TensorShape input{
batchSize, inputChannels, inputSize, inputSize};
TensorShape filter;
if (groups > 1)
filter = TensorShape({groups,
outputChannels / groups,
inputChannels / groups,
filterSize,
filterSize});
else
filter = TensorShape({outputChannels,
inputChannels,
filterSize,
filterSize});
TensorShape output{
batchSize, outputChannels, outputSize, outputSize};
if (type == kForwardTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.run();
} else if (type == kBackwardInputTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
test.run();
} else if (type == kBackwardFilterTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter),
ADD_TO);
test.run();
}
}
}
}
}
}
}
}
}
};
// Mainly used to test cases where the height and width (input, filter)
// are not equal.
template <DeviceType DType1, DeviceType DType2>
class ConvolutionTest2 {
public:
ConvolutionTest2(const std::string& conv1,
const std::string& conv2,
TestType type,
bool useGroups = true,
std::string algo = "auto") {
for (size_t batchSize : {16}) {
for (size_t inputHeight : {7, 31}) {
for (size_t inputWidth : {10, 54}) {
for (size_t filterHeight : {1, 5}) {
for (size_t filterWidth : {3, 7}) {
for (size_t inputChannels : {7}) {
for (size_t outputChannels : {7}) {
size_t groups;
if (!useGroups) {
groups = 1;
} else {
if (outputChannels % inputChannels != 0) continue;
groups = inputChannels;
}
size_t stride = 1;
size_t padding = 0;
size_t outputHeight =
(inputHeight - filterHeight + 2 * padding + stride) /
stride;
size_t outputWidth =
(inputWidth - filterWidth + 2 * padding + stride) /
stride;
VLOG(3) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputHeight
<< " inputWidth=" << inputWidth
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterHeight
<< " filterWidth=" << filterWidth
<< " outputHeight=" << outputHeight
<< " outputWidth=" << outputWidth
<< " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", groups)
.set("algo", algo));
TensorShape input{
batchSize, inputChannels, inputHeight, inputWidth};
TensorShape filter;
if (groups > 1)
filter = TensorShape({groups,
outputChannels / groups,
inputChannels / groups,
filterHeight,
filterWidth});
else
filter = TensorShape({outputChannels,
inputChannels,
filterHeight,
filterWidth});
TensorShape output{
batchSize, outputChannels, outputHeight, outputWidth};
if (type == kForwardTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.run();
} else if (type == kBackwardInputTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
test.run();
} else if (type == kBackwardFilterTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter),
ADD_TO);
test.run();
}
}
}
}
}
}
}
}
}
};
// ======Start Convolution TEST======
TEST(Forward, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
"NaiveConv-CPU", "GemmConv-CPU", kForwardTest, false);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test2(
"NaiveConv-CPU", "GemmConv-CPU", kForwardTest, false);
}
#ifndef PADDLE_ONLY_CPU
TEST(Forward, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConv-CPU", "GemmConv-GPU", kForwardTest, false);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
"GemmConv-CPU", "GemmConv-GPU", kForwardTest, false);
}
TEST(BackwardInput, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradInput-CPU",
"GemmConvGradInput-GPU",
kBackwardInputTest,
false);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
"GemmConvGradInput-CPU",
"GemmConvGradInput-GPU",
kBackwardInputTest,
false);
}
TEST(BackwardFilter, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradFilter-CPU",
"GemmConvGradFilter-GPU",
kBackwardFilterTest,
false);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
"GemmConvGradFilter-CPU",
"GemmConvGradFilter-GPU",
kBackwardFilterTest,
false);
}
#endif
// ======End Convolution TEST======
// ======Start DepthwiseConvolution TEST======
// TODO(zhaolong) The depthwise convolution cpu test will be added when the cpu
// version of depthwiseConv is implemented.
#ifndef PADDLE_ONLY_CPU
TEST(DepthwiseConvForward, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConv-CPU", "DepthwiseConv-GPU", kForwardTest);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
"GemmConv-CPU", "DepthwiseConv-GPU", kForwardTest);
}
TEST(DepthwiseConvBackwardInput, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradInput-CPU",
"DepthwiseConvGradInput-GPU",
kBackwardInputTest);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
"GemmConvGradInput-CPU",
"DepthwiseConvGradInput-GPU",
kBackwardInputTest);
}
TEST(DepthwiseConvBackwardFilter, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradFilter-CPU",
"DepthwiseConvGradFilter-GPU",
kBackwardFilterTest);
ConvolutionTest2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test2(
"GemmConvGradFilter-CPU",
"DepthwiseConvGradFilter-GPU",
kBackwardFilterTest);
}
#endif
// ======End DepthwiseConvolution TEST======
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "FunctionTest.h"
namespace paddle {
template <DeviceType DType1, DeviceType DType2>
void forward(Compare2Function<DType1, DType2>& test,
const TensorShape& input,
const TensorShape& filter,
const TensorShape& output) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.run();
}
template <DeviceType DType1, DeviceType DType2>
void backward_input(Compare2Function<DType1, DType2>& test,
const TensorShape& input,
const TensorShape& filter,
const TensorShape& output) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
test.run();
}
template <DeviceType DType1, DeviceType DType2>
void backward_filter(Compare2Function<DType1, DType2>& test,
const TensorShape& input,
const TensorShape& filter,
const TensorShape& output) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter), ADD_TO);
test.run();
}
template <DeviceType DType1, DeviceType DType2>
using Function = void (*)(Compare2Function<DType1, DType2>& test,
const TensorShape& input,
const TensorShape& filter,
const TensorShape& output);
/**
* \brief A basic convolution function test interface.
*
* \param conv1 type name of convolution function 1.
* \param conv2 type name of convolution function 2.
* \param function test function; can be one of the forward, backward_input,
*                  or backward_filter functions defined above.
* Example:
* 1. Compare GemmConv's CPU and GPU implementation:
* Convolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
* "GemmConv-CPU", "GemmConv-GPU", forward);
*/
template <DeviceType DType1, DeviceType DType2>
void Convolution(const std::string& conv1,
const std::string& conv2,
Function<DType1, DType2> function) {
for (size_t batchSize : {1, 5}) {
for (size_t inputSize : {7, 14, 31}) {
for (size_t filterSize : {1, 3, 5}) {
for (size_t inputChannels : {3, 16}) {
for (size_t outputChannels : {3, 16}) {
if (outputChannels < inputChannels) continue;
for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) {
if (padding >= filterSize) break;
// NNPACK only supports stride = 1 if batchSize > 1
if ((conv1 == "NNPACKConv-CPU" || conv2 == "NNPACKConv-CPU") &&
batchSize > 1 && stride > 1)
break;
size_t outputSize =
(inputSize - filterSize + 2 * padding + stride) / stride;
VLOG(3) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputSize
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize << " stride=" << stride
<< " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", (std::string) "auto"));
TensorShape input{
batchSize, inputChannels, inputSize, inputSize};
TensorShape filter{
outputChannels, inputChannels, filterSize, filterSize};
TensorShape output{
batchSize, outputChannels, outputSize, outputSize};
function(test, input, filter, output);
}
}
}
}
}
}
}
}
/**
* \brief A convolution function test interface for the case where the
*        image height is not equal to the image width.
*/
template <DeviceType DType1, DeviceType DType2>
void Convolution2(const std::string& conv1,
const std::string& conv2,
Function<DType1, DType2> function) {
for (size_t batchSize : {4}) {
for (size_t inputHeight : {7, 31}) {
for (size_t inputWidth : {10, 54}) {
for (size_t filterHeight : {1, 5}) {
for (size_t filterWidth : {3, 7}) {
for (size_t inputChannels : {7}) {
for (size_t outputChannels : {7}) {
size_t stride = 1;
size_t padding = 0;
size_t outputHeight =
(inputHeight - filterHeight + 2 * padding + stride) /
stride;
size_t outputWidth =
(inputWidth - filterWidth + 2 * padding + stride) / stride;
VLOG(3) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputHeight
<< " inputWidth=" << inputWidth
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterHeight
<< " filterWidth=" << filterWidth
<< " outputHeight=" << outputHeight
<< " outputWidth=" << outputWidth
<< " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", (std::string) "auto"));
TensorShape input{
batchSize, inputChannels, inputHeight, inputWidth};
TensorShape filter{
outputChannels, inputChannels, filterHeight, filterWidth};
TensorShape output{
batchSize, outputChannels, outputHeight, outputWidth};
function(test, input, filter, output);
}
}
}
}
}
}
}
}
/**
* \brief A convolution function test interface for depthwise convolution.
*/
template <DeviceType DType1, DeviceType DType2>
void DepthwiseConvolution(const std::string& conv1,
const std::string& conv2,
Function<DType1, DType2> function) {
for (size_t batchSize : {1, 32}) {
for (size_t inputSize : {7, 14, 54}) {
for (size_t filterSize : {3, 4}) {
for (size_t inputChannels : {32}) {
for (size_t outputChannels : {32, 64}) {
for (size_t stride : {1, 2}) {
for (size_t padding : {0, 1}) {
// NNPACK only supports stride = 1 if batchSize > 1,
// and there is a bug when batchSize > 1 and groups != 1.
if ((conv1 == "NNPACKConv-CPU" || conv2 == "NNPACKConv-CPU") &&
batchSize > 1)
break;
size_t outputSize =
(inputSize - filterSize + 2 * padding + stride) / stride;
VLOG(3) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputSize
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize << " stride=" << stride
<< " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
size_t groups = inputChannels;
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", groups)
.set("algo", (std::string) "auto"));
TensorShape input{
batchSize, inputChannels, inputSize, inputSize};
TensorShape filter{groups,
outputChannels / groups,
inputChannels / groups,
filterSize,
filterSize};
TensorShape output{
batchSize, outputChannels, outputSize, outputSize};
function(test, input, filter, output);
}
}
}
}
}
}
}
}
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "ConvOpTest.h"
namespace paddle {
#ifndef PADDLE_ONLY_CPU
TEST(DepthwiseConv, Forward) {
DepthwiseConvolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConv-CPU", "DepthwiseConv-GPU", forward);
}
TEST(DepthwiseConv, BackwardInput) {
DepthwiseConvolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConvGradInput-CPU", "DepthwiseConvGradInput-GPU", backward_input);
}
TEST(DepthwiseConv, BackwardFilter) {
DepthwiseConvolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConvGradFilter-CPU", "DepthwiseConvGradFilter-GPU", backward_filter);
}
#endif
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "ConvOpTest.h"
namespace paddle {
TEST(GemmConv, NaiveConv) {
Convolution<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU>(
"NaiveConv-CPU", "GemmConv-CPU", forward);
Convolution2<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU>(
"NaiveConv-CPU", "GemmConv-CPU", forward);
}
#ifndef PADDLE_ONLY_CPU
TEST(GemmConv, Forward) {
Convolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConv-CPU", "GemmConv-GPU", forward);
Convolution2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConv-CPU", "GemmConv-GPU", forward);
}
TEST(GemmConv, BackwardInput) {
Convolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConvGradInput-CPU", "GemmConvGradInput-GPU", backward_input);
Convolution2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConvGradInput-CPU", "GemmConvGradInput-GPU", backward_input);
}
TEST(GemmConv, BackwardFilter) {
Convolution<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", backward_filter);
Convolution2<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU>(
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", backward_filter);
}
#endif
} // namespace paddle
......@@ -196,30 +196,30 @@ public:
CHECK_EQ(status, nnp_status_success);
}
} else {
for (size_t g = 0; g < groups_; g++) {
// only supports stride = 1
CHECK_EQ(strideH(), 1);
CHECK_EQ(strideW(), 1);
nnp_status status =
nnp_convolution_output(algorithm_,
batchSize,
inputChannels / groups_,
outputChannels / groups_,
inputSize,
padding,
kernelSize,
inputData + inputOffset * g,
filterData + filterOffset * g,
nullptr, /* bias */
outputData + outputOffset * g,
bufferPtr,
sizePtr,
nnp_activation_identity,
nullptr,
threadpool_, /* threadpool */
nullptr);
CHECK_EQ(status, nnp_status_success);
}
// only supports stride = 1
CHECK_EQ(strideH(), 1);
CHECK_EQ(strideW(), 1);
// TODO(hedaoyuan): There is a bug when batchSize > 1 and groups_ > 1.
CHECK_EQ(groups_, static_cast<size_t>(1));
nnp_status status = nnp_convolution_output(algorithm_,
batchSize,
inputChannels,
outputChannels,
inputSize,
padding,
kernelSize,
inputData,
filterData,
nullptr, /* bias */
outputData,
bufferPtr,
sizePtr,
nnp_activation_identity,
nullptr,
threadpool_, /* threadpool */
nullptr);
CHECK_EQ(status, nnp_status_success);
}
}
......
......@@ -13,87 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/function/Function.h"
#include "paddle/function/FunctionTest.h"
DEFINE_string(algo,
"auto",
"The algorithm (auto, ft8x8, ft16x16, wt8x8, "
"implicit-gemm, or direct) for computing convolution of NNPACK.");
#include "paddle/function/ConvOpTest.h"
namespace paddle {
#define IS_NNPACK_SUPPORT(algo, filterSize, stride) \
if (algo == "direct" && filterSize != 1) continue; \
if (algo == "direct" && batchSize != 1) continue; \
if (algo == "wt8x8" && filterSize != 3) continue; \
if (algo == "implicit-gemm" && batchSize != 1) continue; \
if (algo != "auto" && algo != "implicit-gemm" && stride > 1) continue;
class ConvolutionTest {
public:
ConvolutionTest(const std::string& conv1,
const std::string& conv2,
std::string algo = "auto") {
for (size_t batchSize : {1, 32}) {
for (size_t inputSize : {7, 14, 54}) {
for (size_t filterSize : {1, 3, 5}) {
for (size_t inputChannels : {3, 64}) {
for (size_t outputChannels : {3, 64, 128}) {
if (inputChannels < outputChannels) break;
for (size_t stride : {1, 2}) {
// if batchSize > 1 NNPACKConv only supports stride = 1
if (batchSize > 1 && stride > 1) break;
for (size_t padding : {0, 1}) {
if (padding >= filterSize) break;
size_t outputSize =
(inputSize - filterSize + 2 * padding + stride) / stride;
IS_NNPACK_SUPPORT(algo, filterSize, stride);
LOG(INFO) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputSize
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2Function<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", algo));
TensorShape shape0{
batchSize, inputChannels, inputSize, inputSize};
TensorShape shape1{
outputChannels, inputChannels, filterSize, filterSize};
TensorShape shape2{
batchSize, outputChannels, outputSize, outputSize};
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape0));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, shape1));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, shape2));
test.run();
}
}
}
}
}
}
}
}
};
TEST(NNPACK, Forward) {
Convolution<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU>(
"GemmConv-CPU", "NNPACKConv-CPU", forward);
}
TEST(Convolution, NNPACK) {
// NNPACK only supports stride = 1
ConvolutionTest test("GemmConv-CPU", "NNPACKConv-CPU", FLAGS_algo);
TEST(NNPACK, Depthwise) {
DepthwiseConvolution<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU>(
"GemmConv-CPU", "NNPACKConv-CPU", forward);
}
} // namespace paddle
......@@ -112,7 +112,6 @@ BEGIN_DEFINE_ACTIVATION(softmax)
private:
MatrixPtr sftMaxSum_;
MatrixPtr sftMaxDot_;
MatrixPtr one_;
public:
Error __must_check forward(Argument& act) {
......@@ -138,14 +137,6 @@ Error __must_check backward(Argument& act) {
1,
/* trans */ false,
useGpu(act.deviceId));
if (!one_ || one_->getWidth() != outputG->getWidth()) {
Matrix::resizeOrCreate(one_,
1,
outputG->getWidth(),
/* trans */ false,
useGpu(act.deviceId));
one_->one();
}
sftMaxDot_->dotMul(*outputG, *outputV);
sftMaxSum_->colMerge(*sftMaxDot_);
......
......@@ -61,9 +61,6 @@ op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library(fc_op
SRCS fc_op.cc
DEPS mul_op rowwise_add_op sigmoid_op softmax_op net_op)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS op_desc tensor op_registry operator net_op)
cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class AddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2);
......@@ -47,6 +48,7 @@ The equation is: Out = X + Y
};
class AddOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(AddOpGrad, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
};
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2,
......@@ -38,6 +39,8 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
};
class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(OnehotCrossEntropyGradientOp,
framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/net_op.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using OpRegistry = framework::OpRegistry;
class FullyConnectedOp : public NetOp {
public:
void Init() override {
AddOp(OpRegistry::CreateOp("mul",
{
Input("X"), Input("W"),
},
{Output("before_act")}, {}));
auto b = Input("b");
if (b != framework::kEmptyVarName) {
AddOp(OpRegistry::CreateOp("rowwise_add",
{Output("before_act"), Input("b")},
{Output("before_act")}, {}));
}
auto activation = GetAttr<std::string>("activation");
AddOp(OpRegistry::CreateOp(activation, {Output("before_act")},
{Output("Y")}, {}));
CompleteAddOp(false);
}
};
class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FullyConnectedOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input of fc operator");
AddInput("W", "the weight of fc operator");
AddInput("b", "the bias of fc operator");
AddOutput("Y", "the output of fc operator");
AddOutput("before_act", "the before activation output of fc operator")
.SetTemporary();
AddAttr<std::string>("activation", "The activation key for fc layer")
.SetDefault("sigmoid")
.InEnum({"sigmoid", "softmax"});
//! TODO(yuyang18): Complete comment;
AddComment("FullyConnected Operator");
}
};
} // namespace operators
} // namespace paddle
USE_OP(mul);
USE_OP(rowwise_add);
USE_OP(sigmoid);
USE_OP(softmax);
namespace ops = paddle::operators;
REGISTER_OP(fc, ops::FullyConnectedOp, ops::FullyConnectedOpMaker);
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(FillZerosLikeOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL,
......
......@@ -43,6 +43,7 @@ class GaussianRandomKernel : public framework::OpKernel {
};
class GaussianRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(GaussianRandomOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0);
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class MeanOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1, "Input size of MeanOp must be one");
......@@ -39,6 +40,7 @@ class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
};
class MeanGradOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MeanGradOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X"))
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class MulOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs");
......@@ -53,6 +54,7 @@ The equation is: Out = X * Y
};
class MulOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(MulOpGrad, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {}
std::string DebugString() const override {
......
......@@ -35,6 +35,8 @@ namespace operators {
*/
class NetOp : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(NetOp, framework::OperatorBase)
/**
* Infer the shapes of all the operators' input and output variables; called
* before every mini-batch.
......
......@@ -12,6 +12,8 @@ static int run_cnt = 0;
class TestOp : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(TestOp, framework::OperatorBase)
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
......@@ -21,6 +23,8 @@ class TestOp : public framework::OperatorBase {
class EmptyOp : public framework::OperatorBase {
public:
DEFINE_OPERATOR_CTOR(EmptyOp, framework::OperatorBase)
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
};
......
......@@ -100,6 +100,7 @@ class RecurrentGradientAlgorithm {
};
class RecurrentOp final : public framework::OperatorBase {
DEFINE_OPERATOR_CTOR(RecurrentOp, framework::OperatorBase)
public:
void Init() override;
......
......@@ -395,4 +395,4 @@ TEST(RecurrentOp, LinkMemories) {
USE_OP(add_two);
USE_OP(mul);
USE_OP_WITHOUT_KERNEL(recurrent_op);
USE_OP_ITSELF(recurrent_op);
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class RowWiseAddOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(RowWiseAddOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2UL,
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class SGDOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SGDOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2, "Input size of SGDOp must be two");
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class SigmoidOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only has one input");
......@@ -38,6 +39,7 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
};
class SigmoidOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SigmoidOpGrad, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
class SoftmaxOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 1UL,
......@@ -42,6 +43,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
};
class SoftmaxOpGrad : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(SoftmaxOpGrad, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,
......
......@@ -46,6 +46,7 @@ class CPUUniformRandomKernel : public framework::OpKernel {
};
class UniformRandomOp : public framework::OperatorWithKernel {
DEFINE_OPERATOR_CTOR(UniformRandomOp, framework::OperatorWithKernel)
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"),
......
......@@ -66,28 +66,92 @@ void NewRemoteParameterUpdater::init(
// from parameter server
if (paddle_begin_init_params(parameterClient_)) {
LOG(INFO) << "paddle_begin_init_params start";
// NOTE: convert V1 OptimizationConfig proto to V2 OptimizerConfig.
// This makes golang pserver compatible with handy V1 demos.
// TODO(wuyi): Refine or remove these ugly converting lines
OptimizerConfig optimizerConfigV2;
if (trainerConfig_.learning_method() == "momentum") {
optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
} else if (trainerConfig_.learning_method() == "adagrad") {
optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adagrad);
optimizerConfigV2.mutable_adagrad()->set_epsilon(
trainerConfig_.ada_epsilon());
} else if (trainerConfig_.learning_method() == "adadelta") {
optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adadelta);
optimizerConfigV2.mutable_adadelta()->set_epsilon(
trainerConfig_.ada_epsilon());
optimizerConfigV2.mutable_adadelta()->set_rho(trainerConfig_.ada_rou());
} else if (trainerConfig_.learning_method() == "adam") {
optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::Adam);
optimizerConfigV2.mutable_adam()->set_beta_1(trainerConfig_.adam_beta1());
optimizerConfigV2.mutable_adam()->set_beta_2(trainerConfig_.adam_beta2());
optimizerConfigV2.mutable_adam()->set_epsilon(
trainerConfig_.adam_epsilon());
} else {
LOG(ERROR) << "got unsupported v1 optimizer config: "
<< trainerConfig_.learning_method();
optimizerConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
}
if (trainerConfig_.learning_rate_schedule() == "constant") {
optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
optimizerConfigV2.mutable_const_lr()->set_learning_rate(
trainerConfig_.learning_rate());
} else if (trainerConfig_.learning_rate_schedule() == "linear") {
optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Linear);
optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
trainerConfig_.learning_rate());
optimizerConfigV2.mutable_linear_lr()->set_lr_decay_a(
trainerConfig_.learning_rate_decay_a());
optimizerConfigV2.mutable_linear_lr()->set_lr_decay_b(
trainerConfig_.learning_rate_decay_b());
} else {
LOG(ERROR) << "got unsupported v1 learning_rate_schedule config: "
<< trainerConfig_.learning_rate_schedule() << ", set to const";
optimizerConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
}
// overwrite optimizerConfigV2 for per-parameter(layer) configs
for (int i = 0; i < parameterSize(); ++i) {
auto paramConfig = parameters_[i]->getConfig();
LOG(INFO) << "old param config: " << paramConfig.DebugString();
// FIXME(typhoonzero): convert old paramConfig to optimizerConfig
OptimizerConfig optimizeConfigV2;
auto sgdConfigV2 = optimizeConfigV2.mutable_sgd();
sgdConfigV2->set_momentum(paramConfig.momentum());
sgdConfigV2->set_decay(paramConfig.decay_rate());
optimizeConfigV2.set_lr_policy(paddle::OptimizerConfig::Const);
auto constlr = optimizeConfigV2.mutable_const_lr();
if (paramConfig.has_momentum() &&
trainerConfig_.learning_method() == "momentum") {
optimizerConfigV2.mutable_sgd()->set_momentum(paramConfig.momentum());
}
if (paramConfig.has_learning_rate()) {
constlr->set_learning_rate(paramConfig.learning_rate());
} else {
constlr->set_learning_rate(trainerConfig_.learning_rate());
switch (optimizerConfigV2.lr_policy()) {
case 0:
optimizerConfigV2.mutable_const_lr()->set_learning_rate(
paramConfig.learning_rate());
break;
case 1:
optimizerConfigV2.mutable_linear_lr()->set_learning_rate(
paramConfig.learning_rate());
break;
}
}
if (trainerConfig_.algorithm() == "sgd") {
optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
// FIXME: config all algorithms
} else {
optimizeConfigV2.set_optimizer(paddle::OptimizerConfig::SGD);
if (paramConfig.has_decay_rate()) {
switch (optimizerConfigV2.optimizer()) {
case 1: // SGD
optimizerConfigV2.mutable_sgd()->set_decay(
paramConfig.decay_rate());
break;
case 2: // Adadelta
optimizerConfigV2.mutable_adadelta()->set_decay(
paramConfig.decay_rate());
break;
case 3: // Adagrad
optimizerConfigV2.mutable_adagrad()->set_decay(
paramConfig.decay_rate());
break;
case 4: // Adam
optimizerConfigV2.mutable_adam()->set_decay(
paramConfig.decay_rate());
break;
}
}
std::string bytes = optimizeConfigV2.SerializeAsString();
// send param and config to pserver
std::string bytes = optimizerConfigV2.SerializeAsString();
const char *array = bytes.data();
int size = (int)bytes.size();
paddle_init_param(
......
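The converter above is essentially a lookup from the V1 `learning_method` string to a V2 optimizer type, with SGD as the fallback for unsupported methods. A minimal Python sketch of that mapping (the string values and helper name are illustrative, not the generated protobuf API):

```python
# Illustrative V1 -> V2 optimizer mapping; mirrors the if/else chain above.
V1_TO_V2_OPTIMIZER = {
    "momentum": "SGD",  # V1 momentum maps onto V2 SGD (with momentum set)
    "adagrad": "Adagrad",
    "adadelta": "Adadelta",
    "adam": "Adam",
}

def convert_optimizer(learning_method):
    # Unsupported V1 methods fall back to plain SGD, as in the
    # LOG(ERROR) branch of the C++ code above.
    return V1_TO_V2_OPTIMIZER.get(learning_method, "SGD")

assert convert_optimizer("adam") == "Adam"
assert convert_optimizer("rmsprop") == "SGD"  # unsupported -> fallback
```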
py_test(test_net SRCS test_net.py)
py_test(test_fc_op SRCS test_fc_op.py)
py_test(test_scope SRCS test_scope.py)
py_test(test_tensor SRCS test_tensor.py)
......
......@@ -73,21 +73,35 @@ def get_numeric_gradient(op,
def product(dim):
return reduce(lambda a, b: a * b, dim, 1)
# get the input tensor whose numeric gradient we want to compute.
tensor_to_check = local_scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.get_dims())
# prepare a numpy array to store the gradient.
gradient_flat = numpy.zeros(shape=(tensor_size, ), dtype='float32')
# we only compute gradient of one element each time.
# we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size):
# get one input element through its index i.
origin = tensor_to_check.get_float_element(i)
# add delta to it, run op and then get the sum of the result tensor.
x_pos = origin + delta
tensor_to_check.set_float_element(i, x_pos)
y_pos = get_output()
# subtract delta from this element, run op and get the sum of the result tensor.
x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg)
y_neg = get_output()
tensor_to_check.set_float_element(i, origin) # restore old value
# restore old value
tensor_to_check.set_float_element(i, origin)
# compute the gradient of this element and store it into a numpy array.
gradient_flat[i] = (y_pos - y_neg) / delta / 2
# reshape the gradient result to the shape of the source tensor.
return gradient_flat.reshape(tensor_to_check.get_dims())
......
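For reference, the loop above is the central-difference approximation: each element's gradient is (f(x + delta) - f(x - delta)) / (2 * delta). A self-contained sketch of the same idea on a plain numpy function (the helper name is illustrative):

```python
import numpy

def numeric_gradient(f, x, delta=0.005):
    """Central-difference gradient of a scalar-valued f at point x."""
    grad = numpy.zeros_like(x, dtype='float32')
    flat_x = x.reshape(-1)   # a view: writes go through to x
    flat_g = grad.reshape(-1)
    for i in range(flat_x.size):
        origin = flat_x[i]
        flat_x[i] = origin + delta
        y_pos = f(x)
        flat_x[i] = origin - delta
        y_neg = f(x)
        flat_x[i] = origin   # restore the old value
        flat_g[i] = (y_pos - y_neg) / delta / 2
    return grad

# Example: the gradient of sum(x**2) is 2*x.
x = numpy.array([1.0, 2.0, 3.0], dtype='float32')
print(numeric_gradient(lambda v: (v ** 2).sum(), x))  # ~[2. 4. 6.]
```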
import paddle.v2.framework.core as core
import unittest
import numpy
from paddle.v2.framework.op import Operator
class TestFc(unittest.TestCase):
def test_fc(self):
scope = core.Scope()
place = core.CPUPlace()
x = scope.new_var("X")
x_tensor = x.get_tensor()
x_tensor.set_dims([1000, 784])
x_tensor.alloc_float(place)
w = scope.new_var("W")
w_tensor = w.get_tensor()
w_tensor.set_dims([784, 100])
w_tensor.alloc_float(place)
w_tensor.set(numpy.random.random((784, 100)).astype("float32"), place)
# Set a real numpy array here.
# x_tensor.set(numpy.array([]))
op = Operator("fc", X="X", Y="Y", W="W")
for out in op.outputs():
if scope.find_var(out) is None:
scope.new_var(out).get_tensor()
tensor = scope.find_var("Y").get_tensor()
op.infer_shape(scope)
self.assertEqual([1000, 100], tensor.shape())
ctx = core.DeviceContext.create(place)
op.run(scope, ctx)
# After completing all ops, check whether Y is as expected.
if __name__ == '__main__':
unittest.main()
......@@ -3,6 +3,15 @@ from paddle.v2.framework.op import Operator
import unittest
def fc(X, W, Y):
ret_v = core.Net.create()
ret_v.add_op(Operator("mul", X="X", Y="W", Out="pre_activation"))
ret_v.add_op(Operator("sigmoid", X="pre_activation", Y=Y))
ret_v.complete_add_op(True)
return ret_v
class TestNet(unittest.TestCase):
def test_net_all(self):
net = core.Net.create()
......@@ -10,18 +19,18 @@ class TestNet(unittest.TestCase):
net.add_op(op1)
net2 = core.Net.create()
net2.add_op(Operator("fc", X="X", W="w", Y="fc.out"))
net2.add_op(fc(X="X", W="w", Y="fc.out"))
net2.complete_add_op(True)
net.add_op(net2)
net.complete_add_op(True)
expected = '''
Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, Out, fc.out).
Op(plain_net), inputs:(W, X, Y), outputs:(Out, fc.out, pre_activation).
Op(add_two), inputs:(X, Y), outputs:(Out).
Op(plain_net), inputs:(@EMPTY@, X, w), outputs:(@TEMP@fc@0, fc.out).
Op(fc), inputs:(X, w, @EMPTY@), outputs:(fc.out, @TEMP@fc@0).
Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0).
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out).
Op(plain_net), inputs:(W, X), outputs:(fc.out, pre_activation).
Op(plain_net), inputs:(W, X), outputs:(fc.out, pre_activation).
Op(mul), inputs:(X, W), outputs:(pre_activation).
Op(sigmoid), inputs:(pre_activation), outputs:(fc.out).
'''
self.assertEqual(expected, "\n" + str(net))
......
import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
import paddle.trainer_config_helpers.optimizers as v1_optimizers
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Optimizers(update equation) for SGD method.
TODO(zhihong) : create new optimizer with proto config, add new optimizer here
TODO(yuyang18): Complete comments.
"""
import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
import paddle.trainer_config_helpers.optimizers as v1_optimizers
from paddle.proto.OptimizerConfig_pb2 import OptimizerConfig
__all__ = [
'Momentum', 'Adam', 'Adamax', 'AdaGrad', 'DecayedAdaGrad', 'AdaDelta',
'RMSProp', 'ModelAverage', 'L2Regularization'
......@@ -70,7 +83,8 @@ class Optimizer(object):
gradient_machine.prefetch(in_args)
parameter_updater.getParametersRemote()
:param pserver_spec: pserver location, eg: localhost:3000
:param pserver_spec: pserver location, eg: localhost:3000; if etcd is used,
pserver_spec should be the etcd endpoints, eg: http://localhost:2379
:return: parameter_updater
"""
if is_local:
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
import paddle.trainer.config_parser as cp
......@@ -113,16 +127,7 @@ class Parameters(object):
"""
return iter(self.__param_conf__)
def __getitem__(self, key):
"""
Get parameter by parameter name. It uses Python dict syntax.
:note: It will always copy the parameter from C++ side.
:param key: Parameter name
:type key: basestring
:return: parameter value
:rtype: np.ndarray
"""
def __getter_inner(self, key, param_type):
import py_paddle.swig_paddle as api
shape = self.get_shape(key)
......@@ -138,7 +143,7 @@ class Parameters(object):
each_gradient_machine, key)
# for simplify implementation now, we always copy from C++
assert isinstance(param, api.Parameter)
val = param.getBuf(api.PARAMETER_VALUE)
val = param.getBuf(param_type)
assert isinstance(val, api.Vector)
val = val.copyToNumpyArray()
return val
......@@ -146,6 +151,19 @@ class Parameters(object):
raise RuntimeError("Unexpected branch")
def __getitem__(self, key):
"""
Get parameter by parameter name. It uses Python dict syntax.
:note: It will always copy the parameter from C++ side.
:param key: Parameter name
:type key: basestring
:return: parameter value
:rtype: np.ndarray
"""
import py_paddle.swig_paddle as api
return self.__getter_inner(key, api.PARAMETER_VALUE)
def get_shape(self, key):
"""
get shape of the parameter.
......@@ -202,6 +220,19 @@ class Parameters(object):
"""
return self.__getitem__(key=parameter_name)
def get_grad(self, key):
"""
Get gradient by parameter name.
:note: It will always copy the parameter from C++ side.
:param key: parameter name
:type key: basestring
:return: The gradient matrix.
:rtype: np.ndarray
"""
import py_paddle.swig_paddle as api
return self.__getter_inner(key, api.PARAMETER_GRADIENT)
def set(self, parameter_name, value):
"""
Set parameter by parameter name & matrix.
......@@ -250,7 +281,13 @@ class Parameters(object):
size = reduce(lambda a, b: a * b, param.shape)
f.write(struct.pack("IIQ", 0, 4, size))
param = param.astype(np.float32)
f.write(param.tostring())
s = param.tostring()
wrote_size = 0
buf = buffer(s, wrote_size, 65535)
while buf:  # f.write crashes with big data blobs.
f.write(buf)
wrote_size += 65535
buf = buffer(s, wrote_size, 65535)
def deserialize(self, name, f):
"""
......
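The serialize change above works around a crash seen when writing one very large blob with a single `f.write`, by emitting it in 65535-byte slices. A minimal sketch of the same chunking pattern, using `memoryview` in place of the Python 2 `buffer` built-in (the helper name and constant are illustrative):

```python
import io
import numpy as np

CHUNK = 65535  # write in <=64 KiB slices; one huge f.write can crash

def write_in_chunks(f, data):
    """Write a large bytes blob to f in fixed-size slices."""
    view = memoryview(data)
    for start in range(0, len(view), CHUNK):
        f.write(view[start:start + CHUNK])

# Usage: serialize a large float32 parameter safely.
param = np.zeros(10 ** 6, dtype=np.float32)
buf = io.BytesIO()
write_in_chunks(buf, param.tobytes())
assert buf.getvalue() == param.tobytes()
```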
......@@ -161,14 +161,14 @@ class SGD(object):
self.__parameter_updater__.update(each_param)
cost_sum = out_args.sum()
cost = cost_sum / len(data_batch)
self.__parameter_updater__.finishBatch(cost)
batch_evaluator.finish()
event_handler(
v2_event.EndIteration(
pass_id=pass_id,
batch_id=batch_id,
cost=cost,
evaluator=batch_evaluator))
self.__parameter_updater__.finishBatch(cost)
batch_evaluator.finish()
self.__parameter_updater__.finishPass()
pass_evaluator.finish()
......
requests==2.9.2
numpy>=1.12
protobuf==3.1
recordio
matplotlib
rarfile
scipy>=0.19.0
Pillow
nltk>=3.2.2
from setuptools import setup, Distribution
class BinaryDistribution(Distribution):
def has_ext_modules(foo):
return True
......@@ -18,15 +17,8 @@ packages=['paddle',
'paddle.v2.framework.proto',
'py_paddle']
setup_requires=["requests",
"numpy>=1.12",
"protobuf==3.1",
"recordio",
"matplotlib",
"rarfile",
"scipy>=0.19.0",
"Pillow",
"nltk>=3.2.2"]
with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f:
setup_requires = f.read().splitlines()
if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']:
setup_requires+=["opencv-python"]
......