Commit 816e556b authored by Luo Tao

Merge branch 'develop' into fluid_infer

......@@ -37,6 +37,7 @@ Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddl
- Optimized math operations through SSE/AVX intrinsics, BLAS libraries
(e.g. MKL, OpenBLAS, cuBLAS) or customized CPU/GPU kernels.
- Optimized CNN networks through the MKL-DNN library.
- Highly optimized recurrent networks which can handle **variable-length**
  sequences without padding.
- Optimized local and distributed training for models with high dimensional
......
......@@ -358,3 +358,132 @@ reduce_min
.. autofunction:: paddle.v2.fluid.layers.reduce_min
:noindex:
logsigmoid
----------
.. autofunction:: paddle.v2.fluid.layers.logsigmoid
:noindex:
exp
---
.. autofunction:: paddle.v2.fluid.layers.exp
:noindex:
relu
----
.. autofunction:: paddle.v2.fluid.layers.relu
:noindex:
tanh
----
.. autofunction:: paddle.v2.fluid.layers.tanh
:noindex:
tanh_shrink
-----------
.. autofunction:: paddle.v2.fluid.layers.tanh_shrink
:noindex:
softshrink
----------
.. autofunction:: paddle.v2.fluid.layers.softshrink
:noindex:
sqrt
----
.. autofunction:: paddle.v2.fluid.layers.sqrt
:noindex:
abs
----
.. autofunction:: paddle.v2.fluid.layers.abs
:noindex:
ceil
----
.. autofunction:: paddle.v2.fluid.layers.ceil
:noindex:
floor
-----
.. autofunction:: paddle.v2.fluid.layers.floor
:noindex:
round
-----
.. autofunction:: paddle.v2.fluid.layers.round
:noindex:
reciprocal
----------
.. autofunction:: paddle.v2.fluid.layers.reciprocal
:noindex:
log
---
.. autofunction:: paddle.v2.fluid.layers.log
:noindex:
square
------
.. autofunction:: paddle.v2.fluid.layers.square
:noindex:
softplus
--------
.. autofunction:: paddle.v2.fluid.layers.softplus
:noindex:
softsign
---------
.. autofunction:: paddle.v2.fluid.layers.softsign
:noindex:
brelu
-----
.. autofunction:: paddle.v2.fluid.layers.brelu
:noindex:
leaky_relu
----------
.. autofunction:: paddle.v2.fluid.layers.leaky_relu
:noindex:
soft_relu
---------
.. autofunction:: paddle.v2.fluid.layers.soft_relu
:noindex:
elu
----
.. autofunction:: paddle.v2.fluid.layers.elu
:noindex:
relu6
-----
.. autofunction:: paddle.v2.fluid.layers.relu6
:noindex:
pow
----
.. autofunction:: paddle.v2.fluid.layers.pow
:noindex:
hard_shrink
-----------
.. autofunction:: paddle.v2.fluid.layers.hard_shrink
:noindex:
thresholded_relu
----------------
.. autofunction:: paddle.v2.fluid.layers.thresholded_relu
:noindex:
hard_sigmoid
-------------
.. autofunction:: paddle.v2.fluid.layers.hard_sigmoid
:noindex:
swish
------
.. autofunction:: paddle.v2.fluid.layers.swish
:noindex:
......@@ -32,6 +32,16 @@ PaddlePaddle mainly uses `CMake <https://cmake.org>`_ together with GCC/G++ as its compilers
pip install build/python/dist/*.whl
If PaddlePaddle has already been installed on the machine, there are two options:
.. code-block:: bash
1. Uninstall the previous version first, then reinstall
pip uninstall paddlepaddle
pip install build/python/dist/*.whl
2. Upgrade directly to the newer version
pip install build/python/dist/*.whl -U
.. _run_test:
......
......@@ -36,6 +36,16 @@ machine or copy it to the target machine.
pip install build/python/dist/*.whl
If PaddlePaddle has been installed on the machine before, there are two options:
.. code-block:: bash
1. Uninstall and reinstall
pip uninstall paddlepaddle
pip install build/python/dist/*.whl
2. Upgrade directly
pip install build/python/dist/*.whl -U
.. _run_test:
......
......@@ -24,7 +24,7 @@
- `framework::OperatorWithKernel`: inherits from OperatorBase; such an Op has a compute function and is said to have a Kernel.
- `class OpProtoAndCheckerMaker`: describes the Op's inputs, outputs, attributes, and comments; mainly used for generating the Python API.
Depending on whether they contain a kernel, Ops fall into two kinds: Ops with a Kernel and Ops without one. The former are defined by inheriting from `OperatorBase`, the latter from `OperatorWithKernel`. This tutorial mainly explains how to write an Op with a Kernel; the contents an Op must provide are summarized as follows:
Depending on whether they contain a kernel, Ops fall into two kinds: Ops with a Kernel and Ops without one. The former are defined by inheriting from `OperatorWithKernel`, the latter from `OperatorBase`. This tutorial mainly explains how to write an Op with a Kernel; the contents an Op must provide are summarized as follows:
Content | Where defined
......
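To make the summary above concrete, the following is a minimal sketch of a kernel-bearing Op written against these conventions; MyOp, MyOpMaker, MyOpKernel, and my_op are hypothetical names used for illustration only and are not part of this PR.

// Minimal sketch of an Op with a Kernel (hypothetical names; headers assumed).
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

class MyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));  // Out has X's shape
  }
};

class MyOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  MyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "(Tensor), the input tensor.");
    AddOutput("Out", "(Tensor), the output tensor.");
    AddComment(R"DOC(Identity-like demo operator.)DOC");
  }
};

template <typename DeviceContext, typename T>
class MyOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<framework::Tensor>("X");
    auto* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    framework::Copy(*x, ctx.GetPlace(), out);  // identity: Out = X
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(my_op, ops::MyOp, ops::MyOpMaker);
REGISTER_OP_CPU_KERNEL(
    my_op, ops::MyOpKernel<paddle::platform::CPUDeviceContext, float>);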
......@@ -31,15 +31,14 @@ static const platform::DeviceContext* GetDeviceContext(
}
}
Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place) {
void DeviceTransform(const Tensor& in, const platform::Place& dst_place,
Tensor* out) {
VLOG(3) << "DeviceTransform in, src_place " << in.place()
<< " dst_place: " << dst_place;
Tensor* out = new Tensor();
auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
dev_ctx->Wait();
Copy(in, dst_place, *dev_ctx, out);
dev_ctx->Wait();
return out;
}
} // namespace framework
......
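A sketch of a call site for the reworked signature; the header paths and wrapper function are assumed for illustration:

// Sketch: the caller now owns the output tensor, and DeviceTransform fills it
// in place instead of returning a newly heap-allocated Tensor.
#include "paddle/framework/tensor.h"   // header paths assumed
#include "paddle/platform/place.h"

void TransformToGPU(const paddle::framework::Tensor& src) {
  paddle::framework::Tensor dst;
  paddle::framework::DeviceTransform(src, paddle::platform::CUDAPlace(0), &dst);
  // dst now holds a copy of src on GPU 0; no ownership transfer is involved.
}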
......@@ -21,7 +21,8 @@ limitations under the License. */
namespace paddle {
namespace framework {
Tensor* DeviceTransform(const Tensor& in, const platform::Place& dst_place);
void DeviceTransform(const Tensor& in, const platform::Place& dst_place,
Tensor* out);
} // namespace framework
} // namespace paddle
......@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#include <iostream>
#include <cctype>
#include <ostream>
#include "paddle/platform/enforce.h"
namespace paddle {
......@@ -27,12 +29,19 @@ enum class DataLayout {
};
inline DataLayout StringToDataLayout(const std::string& str) {
if (str == "NHWC" || str == "nhwc") {
std::string s(str);
for (size_t i = 0; i < s.size(); ++i) {
s[i] = toupper(s[i]);
}
if (s == "NHWC") {
return DataLayout::kNHWC;
} else if (str == "NCHW" || str == "nchw") {
} else if (s == "NCHW") {
return DataLayout::kNCHW;
} else if (s == "ANYLAYOUT") {
return DataLayout::kAnyLayout;
} else {
PADDLE_THROW("Unknown storage order string: %s", str);
PADDLE_THROW("Unknown storage order string: %s", s);
}
}
......@@ -49,7 +58,7 @@ inline std::string DataLayoutToString(const DataLayout& data_layout) {
}
}
inline std::ostream& operator<<(std::ostream& out, DataLayout l) {
inline std::ostream& operator<<(std::ostream& out, const DataLayout& l) {
out << DataLayoutToString(l);
return out;
}
......
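With the upper-casing loop above, layout parsing becomes case-insensitive and the new kAnyLayout value is accepted; a small usage sketch (header path assumed):

// Sketch: mixed-case strings now parse correctly instead of throwing.
#include "paddle/framework/data_layout.h"

void LayoutParseExample() {
  using paddle::framework::DataLayout;
  using paddle::framework::StringToDataLayout;
  DataLayout a = StringToDataLayout("nHwC");       // DataLayout::kNHWC
  DataLayout b = StringToDataLayout("AnyLayout");  // DataLayout::kAnyLayout
  (void)a;
  (void)b;
}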
......@@ -19,16 +19,14 @@ limitations under the License. */
namespace paddle {
namespace framework {
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor) {
Tensor* out = nullptr;
void DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor, Tensor* out) {
if (!platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_type.place_)) {
out = DeviceTransform(input_tensor, expected_kernel_type.place_);
DeviceTransform(input_tensor, expected_kernel_type.place_, out);
}
PADDLE_ENFORCE_NOT_NULL(out, "out should not be null");
return out;
}
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
......
......@@ -30,9 +30,9 @@ limitations under the License. */
namespace paddle {
namespace framework {
Tensor* DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor);
void DataTransform(const OpKernelType& expected_kernel_type,
const OpKernelType& kernel_type_for_var,
const Tensor& input_tensor, Tensor* out);
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
Variable& out_var);
......
......@@ -85,5 +85,10 @@ inline std::string KernelTypeToString(const OpKernelType& kernel_key) {
return stream.str();
}
inline bool TransFromNeeded(const OpKernelType& l, const OpKernelType& r) {
return (!platform::places_are_same_class(l.place_, r.place_)) ||
(l.data_type_ != r.data_type_) || (l.data_layout_ != r.data_layout_);
}
} // namespace framework
} // namespace paddle
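TransFromNeeded makes the transform condition explicit: any mismatch in place class, data type, or data layout between two kernel types requires a data transform. A sketch with illustrative values (the OpKernelType constructor argument order is assumed here):

// Sketch: CPU and CUDA places belong to different place classes, so a
// transform is needed even though dtype and layout agree.
using paddle::framework::OpKernelType;
OpKernelType cpu_key(paddle::framework::proto::DataType::FP32,
                     paddle::platform::CPUPlace(),
                     paddle::framework::DataLayout::kNCHW,
                     paddle::framework::LibraryType::kPlain);
OpKernelType gpu_key(paddle::framework::proto::DataType::FP32,
                     paddle::platform::CUDAPlace(0),
                     paddle::framework::DataLayout::kNCHW,
                     paddle::framework::LibraryType::kPlain);
bool need = paddle::framework::TransFromNeeded(cpu_key, gpu_key);  // true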
......@@ -368,24 +368,6 @@ TEST(OperatorRegistrar, OpWithMultiKernel) {
// TODO(qiao) add priority back
// use all available kernels
paddle::framework::UseALL();
op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -10);
// remove cuda kernels
paddle::framework::UseCPU();
op->Run(scope, cpu_place);
EXPECT_EQ(op_test_value, -9);
// add cuda kernels
paddle::framework::UseCUDA();
op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -10);
// use cudnn kernel
paddle::framework::UseCUDNN();
op->Run(scope, cuda_place);
EXPECT_EQ(op_test_value, -20);
}
......@@ -29,52 +29,12 @@ DEFINE_bool(op_sync, false,
namespace paddle {
namespace framework {
std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
void UseCPU() {
kKernelPriority.clear();
/*Plain CPU*/
auto pair0 = std::make_tuple(platform::CPUPlace(), LibraryType::kPlain);
kKernelPriority.insert(kKernelPriority.begin(), pair0);
}
void UseMKLDNN() {
UseCPU();
#if PADDLE_WITH_MKLML
{
/*MKLDNN Kernel*/
auto pair0 = std::make_tuple(platform::CPUPlace(), LibraryType::kMKLDNN);
kKernelPriority.insert(kKernelPriority.begin(), pair0);
}
#endif
}
void UseCUDA() {
UseMKLDNN();
#if PADDLE_WITH_CUDA
/*Plain GPU*/
auto pair0 = std::make_tuple(platform::CUDAPlace(0), LibraryType::kPlain);
kKernelPriority.insert(kKernelPriority.begin(), pair0);
#endif
}
void UseCUDNN() {
UseCUDA();
#if PADDLE_WITH_CUDA
if (platform::dynload::HasCUDNN()) {
/*CUDNN Kernel*/
auto pair0 = std::make_tuple(platform::CUDAPlace(0), LibraryType::kCUDNN);
kKernelPriority.insert(kKernelPriority.begin(), pair0);
}
#endif
}
void UseALL() {
UseCPU();
UseMKLDNN();
UseCUDA();
UseCUDNN();
}
std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
std::make_tuple(platform::CUDAPlace(0), LibraryType::kCUDNN),
std::make_tuple(platform::CUDAPlace(0), LibraryType::kPlain),
std::make_tuple(platform::CPUPlace(), LibraryType::kMKLDNN),
std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
};
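This static table replaces the stateful Use*() calls removed above: kernel preference is now a fixed ordering that selection code can scan. A sketch of such a scan, mirroring the loop this PR comments out in OperatorWithKernel::Run below:

// Sketch: pick the first (place, library) pair for which the op actually
// registered a kernel; expected_kernel_key and kernels are named as in
// OperatorWithKernel::Run.
for (auto& candidate : kKernelPriority) {
  OpKernelType candidate_key(expected_kernel_key.data_type_,
                             std::get<0>(candidate),
                             expected_kernel_key.data_layout_,
                             std::get<1>(candidate));
  if (kernels.count(candidate_key)) {
    expected_kernel_key = candidate_key;
    break;
  }
}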
static DDim GetDims(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
......@@ -271,36 +231,33 @@ static bool VarIsTensor(const Variable* var) {
return var->IsType<LoDTensor>() || var->IsType<SelectedRows>();
}
static const Tensor* GetTensorFromVar(const Variable* var) {
const Tensor* t = nullptr;
static const Tensor* GetTensorFromVar(Variable* var) {
if (var->IsType<LoDTensor>()) {
t = &(var->Get<LoDTensor>());
return var->GetMutable<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = &(var->Get<SelectedRows>().value());
return var->GetMutable<SelectedRows>()->mutable_value();
} else {
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var->Type().name());
}
return t;
}
static Tensor* GetMutableTensorFromVar(Variable* var) {
Tensor* t = nullptr;
if (var->IsType<LoDTensor>()) {
t = var->GetMutable<LoDTensor>();
return var->GetMutable<LoDTensor>();
} else if (var->IsType<SelectedRows>()) {
t = var->GetMutable<SelectedRows>()->mutable_value();
return var->GetMutable<SelectedRows>()->mutable_value();
} else {
PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
var->Type().name());
}
return t;
}
template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name);
return var == nullptr ? nullptr : GetTensorFromVar(var);
return var == nullptr ? nullptr
: GetTensorFromVar(const_cast<Variable*>(var));
}
template <>
......@@ -343,6 +300,7 @@ bool OpSupportGPU(const std::string& op_type) {
auto it = all_kernels.find(op_type);
if (it == all_kernels.end()) {
// All control operators must support GPU
return true;
}
for (auto& kern_pair : it->second) {
......@@ -516,21 +474,17 @@ void OperatorWithKernel::Run(const Scope& scope,
}
ExecutionContext ctx(*this, scope, *dev_ctx);
auto expected_kernel_key = this->GetExpectedKernelType(ctx);
OpKernelMap& kernels = kernels_iter->second;
for (auto& candidate : kKernelPriority) {
auto candidate_key =
OpKernelType(expected_kernel_key.data_type_, std::get<0>(candidate),
expected_kernel_key.data_layout_, std::get<1>(candidate));
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the
// transform functions are ready.
if ((candidate_key == expected_kernel_key) ||
(kernels.count(candidate_key))) {
expected_kernel_key = candidate_key;
break;
}
}
// for (auto& candidate : kKernelPriority) {
// Do selection
// }
auto expected_kernel_key = this->GetExpectedKernelType(ctx);
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
......@@ -544,7 +498,7 @@ void OperatorWithKernel::Run(const Scope& scope,
if (tensor_in->IsInitialized()) {
auto kernel_type_for_var = this->GetKernelTypeForVar(
var_name_item.first, *tensor_in, expected_kernel_key);
if (kernel_type_for_var != expected_kernel_key) {
if (TransFromNeeded(kernel_type_for_var, expected_kernel_key)) {
auto out_var_names = OutputVars(true);
if (std::find(out_var_names.begin(), out_var_names.end(),
var_name) != out_var_names.end()) {
......@@ -553,11 +507,13 @@ void OperatorWithKernel::Run(const Scope& scope,
"does not support transform",
var_name);
}
VLOG(3) << "need to do transform for var " << var_name;
VLOG(3) << "Transform Variable " << var_name << " from "
<< kernel_type_for_var << " to " << expected_kernel_key;
auto* trans_var = new_scope.Var(var_name);
auto* out = DataTransform(expected_kernel_key, kernel_type_for_var,
*tensor_in);
CopyVariableWithTensor(*var, *out, *trans_var);
std::shared_ptr<Tensor> out(new Tensor);
DataTransform(expected_kernel_key, kernel_type_for_var, *tensor_in,
out.get());
CopyVariableWithTensor(*var, *(out.get()), *trans_var);
}
}
}
......
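In summary, the reworked Run() path computes expected_kernel_key, asks GetKernelTypeForVar for each initialized input tensor, and calls TransFromNeeded; on a mismatch it stages the transformed tensor in a child scope. A condensed restatement of the hunk above (no new behavior):

if (TransFromNeeded(kernel_type_for_var, expected_kernel_key)) {
  auto* trans_var = new_scope.Var(var_name);  // shadow the var in a child scope
  std::shared_ptr<Tensor> out(new Tensor);
  DataTransform(expected_kernel_key, kernel_type_for_var, *tensor_in,
                out.get());
  CopyVariableWithTensor(*var, *(out.get()), *trans_var);
}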
......@@ -54,33 +54,9 @@ constexpr char kGradVarSuffix[] = "@GRAD";
constexpr char kZeroVarSuffix[] = "@ZERO";
// define some kernel priority
/* Define multiple kernel type fallback order*/
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;
/**
* @brief Use cpu kernel only
*/
void UseCPU();
/**
* @brief Prefer MKLDNN kernel over plain CPU kernel
*/
void UseMKLDNN();
/**
* @brief Prefer CUDA kernel over plain CPU kernel
*/
void UseCUDA();
/**
* @brief Prefer cudnn kernel over plain CUDA kernel
*/
void UseCUDNN();
/**
* @brief Use all available kernels
*/
void UseALL();
inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix;
}
......
......@@ -137,8 +137,6 @@ op_library(sum_op DEPS selected_rows_functor)
op_library(sgd_op DEPS selected_rows_functor)
op_library(print_op DEPS lod_tensor)
op_library(adagrad_op DEPS selected_rows_functor)
op_library(conv_op DEPS vol2col)
op_library(pool_op DEPS pooling)
op_library(maxout_op DEPS maxouting)
op_library(unpool_op DEPS unpooling)
op_library(pool_with_index_op DEPS pooling)
......@@ -149,12 +147,27 @@ op_library(max_sequence_len_op DEPS lod_rank_table)
op_library(sequence_conv_op DEPS context_project)
op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(conv_transpose_op DEPS vol2col)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding math_function)
op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor)
# Register multiple kernels to pybind
if (WITH_GPU)
op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS vol2col)
op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
conv_transpose_cudnn_op.cu.cc DEPS vol2col)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d, CUDNN);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(pool2d, CUDNN);\n")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(conv2d_transpose, CUDNN);\n")
else()
op_library(conv_op SRCS conv_op.cc DEPS vol2col)
op_library(pool_op SRCS pool_op.cc DEPS pooling)
op_library(conv_transpose_op SRCS conv_transpose_op.cc DEPS vol2col)
endif()
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
......
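When WITH_GPU is on, the file(APPEND ...) calls above leave the following lines in the generated pybind source, force-linking the CUDNN kernels that this PR registers with REGISTER_OP_KERNEL:

USE_OP_DEVICE_KERNEL(conv2d, CUDNN);
USE_OP_DEVICE_KERNEL(pool2d, CUDNN);
USE_OP_DEVICE_KERNEL(conv2d_transpose, CUDNN);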
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv_op.h"
namespace paddle {
namespace operators {
class CudnnConv2DOpMaker : public Conv2DOpMaker {
public:
CudnnConv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: Conv2DOpMaker(proto, op_checker) {
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardware. This size should be chosen carefully.")
.SetDefault(4096);
}
};
class CudnnConv3DOpMaker : public Conv3DOpMaker {
public:
CudnnConv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: Conv3DOpMaker(proto, op_checker) {
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardware. This size should be chosen carefully.")
.SetDefault(4096);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(conv2d_cudnn, ops::ConvOp, ops::CudnnConv2DOpMaker,
conv2d_cudnn_grad, ops::ConvOpGrad);
REGISTER_OP(conv3d_cudnn, ops::ConvOp, ops::CudnnConv3DOpMaker,
conv3d_cudnn_grad, ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(
conv2d_cudnn,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_cudnn_grad,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_cudnn,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_cudnn_grad,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -32,7 +32,7 @@ static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
static_cast<size_t>(1024) * 1024 * 1024;
template <typename T>
class CudnnConvOpKernel : public framework::OpKernel<T> {
class CUDNNConvOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -147,7 +147,7 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
};
template <typename T>
class CudnnConvGradOpKernel : public framework::OpKernel<T> {
class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -315,17 +315,16 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
// TODO(dzhwinter) : below register should be removed
REGISTER_OP_CUDA_KERNEL(conv2d_cudnn,
paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv2d_cudnn_grad,
paddle::operators::CudnnConvGradOpKernel<float>,
paddle::operators::CudnnConvGradOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv3d_cudnn,
paddle::operators::CudnnConvOpKernel<float>,
paddle::operators::CudnnConvOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv3d_cudnn_grad,
paddle::operators::CudnnConvGradOpKernel<float>,
paddle::operators::CudnnConvGradOpKernel<double>);
REGISTER_OP_KERNEL(conv2d, CUDNN, ::paddle::platform::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>,
paddle::operators::CUDNNConvOpKernel<double>);
REGISTER_OP_KERNEL(conv2d_grad, CUDNN, ::paddle::platform::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>);
REGISTER_OP_KERNEL(conv3d, CUDNN, ::paddle::platform::CUDAPlace,
paddle::operators::CUDNNConvOpKernel<float>,
paddle::operators::CUDNNConvOpKernel<double>);
REGISTER_OP_KERNEL(conv3d_grad, CUDNN, ::paddle::platform::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>);
......@@ -67,6 +67,23 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->ShareLoD("Input", "Output");
}
framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
} else {
library_ = framework::LibraryType::kPlain;
}
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_);
}
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
......@@ -108,6 +125,26 @@ Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
"dilations(h_dilation, w_dilation) of "
"convolution operator.")
.SetDefault({1, 1});
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
// TODO(dzhwinter): need to registered layout transform function
AddAttr<int>("workspace_size_MB",
"Only used in cudnn kernel. Need set use_cudnn to true."
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardware. This size should be chosen carefully.")
.SetDefault(4096);
AddComment(R"DOC(
Convolution Operator.
......@@ -181,6 +218,25 @@ Conv3DOpMaker::Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
"dilations(d_dilation, h_dilation, w_dilation) of "
"convolution operator.")
.SetDefault({1, 1, 1});
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
// TODO(dzhwinter): need to registered layout transform function
AddAttr<int>("workspace_size_MB",
"Only used in cudnn kernel. workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardware. This size should be chosen carefully.")
.SetDefault(4096);
AddComment(R"DOC(
Convolution3D Operator.
......@@ -224,6 +280,23 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
}
}
framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
} else {
library_ = framework::LibraryType::kPlain;
}
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_);
}
} // namespace operators
} // namespace paddle
......
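All six GetExpectedKernelType overrides added in this PR follow one recipe: map use_cudnn to a LibraryType, parse data_format into a DataLayout, and combine both with the input dtype and place. A sketch of how the framework consumes the result; the real lookup lives inside OperatorWithKernel::Run, so this is illustrative only:

// Sketch: the returned OpKernelType is the key into the op's kernel map; with
// use_cudnn=true it resolves to the kernel registered through
// REGISTER_OP_KERNEL(conv2d, CUDNN, ...).
auto expected_kernel_key = this->GetExpectedKernelType(ctx);
auto kernel_iter = kernels.find(expected_kernel_key);
PADDLE_ENFORCE(kernel_iter != kernels.end(), "missing kernel for %s",
               KernelTypeToString(expected_kernel_key));
kernel_iter->second->Compute(ctx);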
......@@ -62,12 +62,20 @@ class ConvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class ConvOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv_transpose_op.h"
namespace paddle {
namespace operators {
class CudnnConv2DTransposeOpMaker : public Conv2DTransposeOpMaker {
public:
CudnnConv2DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: Conv2DTransposeOpMaker(proto, op_checker) {
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardward. This size should be carefully setted.")
.SetDefault(4096);
}
};
class CudnnConv3DTransposeOpMaker : public Conv3DTransposeOpMaker {
public:
CudnnConv3DTransposeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: Conv3DTransposeOpMaker(proto, op_checker) {
AddAttr<int>("workspace_size_MB",
"workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardward. This size should be carefully setted.")
.SetDefault(4096);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp,
ops::CudnnConv2DTransposeOpMaker, conv2d_transpose_cudnn_grad,
ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP(conv3d_transpose_cudnn, ops::ConvTransposeOp,
ops::CudnnConv3DTransposeOpMaker, conv3d_transpose_cudnn_grad,
ops::ConvTransposeOpGrad);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_transpose_cudnn_grad,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
......@@ -28,10 +28,10 @@ using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout;
static constexpr size_t kConvCudnnWorkspaceLimitBytes = 1024 * 1024 * 1024;
static constexpr size_t kConvCUDNNWorkspaceLimitBytes = 1024 * 1024 * 1024;
template <typename T>
class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -77,7 +77,7 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv workspace ---------------------
void* cudnn_workspace = nullptr;
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes;
size_t workspace_size_limit = kConvCUDNNWorkspaceLimitBytes;
if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024;
}
......@@ -116,7 +116,7 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
};
template <typename T>
class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -161,7 +161,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
cudnnConvolutionBwdFilterAlgo_t filter_algo;
size_t bwd_filter_ws_size, fwd_ws_size;
size_t workspace_size_in_bytes = 0;
size_t workspace_size_limit = kConvCudnnWorkspaceLimitBytes;
size_t workspace_size_limit = kConvCUDNNWorkspaceLimitBytes;
if (user_workspace_size > 0) {
workspace_size_limit = user_workspace_size * 1024 * 1024;
}
......@@ -236,16 +236,16 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(conv2d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv2d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv3d_transpose_cudnn,
ops::CudnnConvTransposeOpKernel<float>,
ops::CudnnConvTransposeOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(conv3d_transpose_cudnn_grad,
ops::CudnnConvTransposeGradOpKernel<float>,
ops::CudnnConvTransposeGradOpKernel<double>);
REGISTER_OP_KERNEL(conv2d_transpose, CUDNN, ::paddle::platform::CUDAPlace,
ops::CUDNNConvTransposeOpKernel<float>,
ops::CUDNNConvTransposeOpKernel<double>);
REGISTER_OP_KERNEL(conv2d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace,
ops::CUDNNConvTransposeGradOpKernel<float>,
ops::CUDNNConvTransposeGradOpKernel<double>);
REGISTER_OP_KERNEL(conv3d_transpose, CUDNN, ::paddle::platform::CUDAPlace,
ops::CUDNNConvTransposeOpKernel<float>,
ops::CUDNNConvTransposeOpKernel<double>);
REGISTER_OP_KERNEL(conv3d_transpose_grad, CUDNN, ::paddle::platform::CUDAPlace,
ops::CUDNNConvTransposeGradOpKernel<float>,
ops::CUDNNConvTransposeGradOpKernel<double>);
......@@ -58,6 +58,23 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
} else {
library_ = framework::LibraryType::kPlain;
}
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_);
}
Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto,
OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
......@@ -94,6 +111,25 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(OpProto* proto,
"(vector<int> default:{0, 0}), the paddings(h_pad, w_pad) of convolution "
"transpose operator.")
.SetDefault({0, 0});
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
// TODO(dzhwinter): need to registered layout transform function
AddAttr<int>("workspace_size_MB",
"Used in cudnn kernel only. workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardward. This size should be carefully setted.")
.SetDefault(4096);
AddComment(R"DOC(
Convolution2D Transpose Operator.
......@@ -163,6 +199,25 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(OpProto* proto,
"(vector<int> default:{0, 0, 0}), paddings(d_pad, "
"h_pad, w_pad) of convolution transpose operator.")
.SetDefault({0, 0, 0});
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
// TODO(dzhwinter): need to registered layout transform function
AddAttr<int>("workspace_size_MB",
"Used in cudnn kernel only. workspace size for cudnn, in MB, "
"workspace is a section of GPU memory which will be "
"allocated/freed each time the operator runs, larger "
"workspace size can increase performance but also requires "
"better hardward. This size should be carefully setted.")
.SetDefault(4096);
AddComment(R"DOC(
Convolution3D Transpose Operator.
......@@ -205,6 +260,23 @@ void ConvTransposeOpGrad::InferShape(framework::InferShapeContext* ctx) const {
}
}
framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
} else {
library_ = framework::LibraryType::kPlain;
}
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
layout_, library_);
}
} // namespace operators
} // namespace paddle
......
......@@ -42,12 +42,20 @@ class ConvTransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class ConvTransposeOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence2batch.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_cudnn_op.h"
namespace ops = paddle::operators;
REGISTER_OP(pool2d_cudnn, ops::PoolOp, ops::Pool2dOpMaker, pool2d_cudnn_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(
pool2d_cudnn, ops::PoolKernel<paddle::platform::CPUDeviceContext, float>,
ops::PoolKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
pool2d_cudnn_grad,
ops::PoolGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::PoolGradKernel<paddle::platform::CPUDeviceContext, double>)
REGISTER_OP(pool3d_cudnn, ops::PoolOp, ops::Pool3dOpMaker, pool3d_cudnn_grad,
ops::PoolOpGrad);
REGISTER_OP_CPU_KERNEL(
pool3d_cudnn, ops::PoolKernel<paddle::platform::CPUDeviceContext, float>,
ops::PoolKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
pool3d_cudnn_grad,
ops::PoolGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::PoolGradKernel<paddle::platform::CPUDeviceContext, double>)
......@@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_cudnn_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/pool_op.h"
#include "paddle/platform/cudnn_helper.h"
namespace paddle {
......@@ -25,7 +26,7 @@ using DataLayout = platform::DataLayout;
using PoolingMode = platform::PoolingMode;
template <typename T>
class PoolCudnnOpKernel : public framework::OpKernel<T> {
class PoolCUDNNOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -86,7 +87,7 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
};
template <typename T>
class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -162,12 +163,16 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel<float>,
ops::PoolCudnnOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>,
ops::PoolCudnnGradOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(pool3d_cudnn, ops::PoolCudnnOpKernel<float>,
ops::PoolCudnnOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(pool3d_cudnn_grad, ops::PoolCudnnGradOpKernel<float>,
ops::PoolCudnnGradOpKernel<double>);
REGISTER_OP_KERNEL(pool2d, CUDNN, ::paddle::platform::CUDAPlace,
ops::PoolCUDNNOpKernel<float>,
ops::PoolCUDNNOpKernel<double>);
REGISTER_OP_KERNEL(pool2d_grad, CUDNN, ::paddle::platform::CUDAPlace,
ops::PoolCUDNNGradOpKernel<float>,
ops::PoolCUDNNGradOpKernel<double>);
REGISTER_OP_KERNEL(pool3d, CUDNN, ::paddle::platform::CUDAPlace,
ops::PoolCUDNNOpKernel<float>,
ops::PoolCUDNNOpKernel<double>);
REGISTER_OP_KERNEL(pool3d_grad, CUDNN, ::paddle::platform::CUDAPlace,
ops::PoolCUDNNGradOpKernel<float>,
ops::PoolCUDNNGradOpKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/pool_op.h"
namespace paddle {
namespace operators {} // namespace operators
} // namespace paddle
......@@ -61,6 +61,23 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
ctx->ShareLoD("X", "Out");
}
framework::OpKernelType PoolOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
} else {
library_ = framework::LibraryType::kPlain;
}
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout_, library_);
}
void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
......@@ -68,6 +85,23 @@ void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
framework::LibraryType library_;
if (use_cudnn) {
library_ = framework::LibraryType::kCUDNN;
} else {
library_ = framework::LibraryType::kPlain;
}
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
layout_, library_);
}
Pool2dOpMaker::Pool2dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
......@@ -101,15 +135,27 @@ Pool2dOpMaker::Pool2dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
AddAttr<std::vector<int>>("strides",
"(vector<int>, default {1, 1}), strides(height, "
"width) of pooling operator.")
.SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently,
.SetDefault({1, 1});
// TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<std::vector<int>>(
"paddings",
"(vector<int>, default {0,0}), paddings(height, width) of pooling "
"operator."
"If global_pooling = true, paddings and ksize will be ignored.")
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
.SetDefault({0, 0});
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
// TODO(dzhwinter): need to registered layout transform function
AddComment(R"DOC(
Pool2d Operator.
......@@ -182,6 +228,19 @@ Pool3dOpMaker::Pool3dOpMaker(OpProto *proto, OpAttrChecker *op_checker)
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr<bool>(
"use_cudnn",
"(bool, default false) Only used in cudnn kernel, need install cudnn")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
// TODO(dzhwinter): need to registered layout transform function
AddComment(R"DOC(
Pool3d Operator.
......
......@@ -29,6 +29,10 @@ class PoolOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class PoolOpGrad : public framework::OperatorWithKernel {
......@@ -36,6 +40,10 @@ class PoolOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -16,12 +16,17 @@
#include <ctime>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/variable.h"
namespace paddle {
namespace operators {
#define CLOG std::cout
const std::string kForward = "FORWARD";
const std::string kBackward = "BACKWARD";
const std::string kBoth = "BOTH";
struct Formater {
std::string message;
std::string name;
......@@ -122,40 +127,77 @@ class TensorPrintOp : public framework::OperatorBase {
TensorPrintOp(const TensorPrintOp& o)
: framework::OperatorBase(
static_cast<const framework::OperatorBase&>(o)) {
PADDLE_THROW("Not implemented");
PADDLE_THROW("Not implemented.");
}
void Run(const framework::Scope& scope,
const platform::Place& place) const override {
// Only run the `first_n` times.
const framework::Variable* in_var_ptr = nullptr;
std::string phase = kForward;
std::string printed_var_name = "";
auto& inputs = Inputs();
if (inputs.find("In") != inputs.end() && !Inputs("In").empty()) {
in_var_ptr = scope.FindVar(Input("In"));
printed_var_name = Inputs("In").front();
} else if (inputs.find("In@GRAD") != inputs.end() &&
!Inputs("In@GRAD").empty()) {
in_var_ptr = scope.FindVar(Input("In@GRAD"));
printed_var_name = Inputs("In@GRAD").front();
phase = kBackward;
} else {
PADDLE_THROW("Unknown phase, should be forward or backward.");
}
PADDLE_ENFORCE_NOT_NULL(in_var_ptr);
auto& in_tensor = in_var_ptr->Get<framework::LoDTensor>();
auto* out_var_ptr = scope.FindVar(Output("Out"));
auto& out_tensor = *out_var_ptr->GetMutable<framework::LoDTensor>();
// Just copy data from input tensor to output tensor
// output tensor share same memory with input tensor
out_tensor.ShareDataWith(in_tensor);
out_tensor.set_lod(in_tensor.lod());
std::string print_phase = Attr<std::string>("print_phase");
if (print_phase != phase && print_phase != kBoth) {
return;
}
int first_n = Attr<int>("first_n");
if (first_n > 0 && ++times_ > first_n) return;
PADDLE_ENFORCE(!Inputs("input").empty(), "input should be set");
auto* input_var = scope.FindVar(Input("input"));
PADDLE_ENFORCE_NOT_NULL(input_var);
auto& tensor = input_var->Get<framework::LoDTensor>();
framework::LoDTensor printed_tensor;
printed_tensor.set_lod(in_tensor.lod());
printed_tensor.Resize(in_tensor.dims());
// TODO(ChunweiYan) support GPU
PADDLE_ENFORCE(platform::is_cpu_place(tensor.place()));
if (platform::is_cpu_place(in_tensor.place())) {
printed_tensor.ShareDataWith(in_tensor);
} else {
// copy data to cpu to print
platform::CPUPlace place;
framework::Copy(in_tensor, place, &printed_tensor);
}
Formater formater;
if (Attr<bool>("print_tensor_name")) {
formater.name = Inputs("input").front();
formater.name = printed_var_name;
}
if (Attr<bool>("print_tensor_type")) {
formater.dtype = tensor.type();
formater.dtype = printed_tensor.type();
}
if (Attr<bool>("print_tensor_shape")) {
formater.dims.assign(tensor.dims()[0],
tensor.dims()[tensor.dims().size() - 1]);
auto& dims = printed_tensor.dims();
formater.dims.resize(dims.size());
for (int i = 0; i < dims.size(); ++i) formater.dims[i] = dims[i];
}
if (Attr<bool>("print_tensor_lod")) {
formater.lod = tensor.lod();
formater.lod = printed_tensor.lod();
}
formater.summarize = Attr<int>("summarize");
formater.data = (void*)tensor.data<void>();
formater(tensor.numel());
formater.data = (void*)printed_tensor.data<void>();
formater(printed_tensor.numel());
}
private:
......@@ -166,27 +208,46 @@ class PrintOpProtoAndCheckMaker : public framework::OpProtoAndCheckerMaker {
public:
PrintOpProtoAndCheckMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("input", "the tensor that will be displayed.");
AddInput("In", "Input tensor to be displayed.");
AddAttr<int>("first_n", "Only log `first_n` number of times.");
AddAttr<std::string>("message", "A string message to print as a prefix.");
AddAttr<int>("summarize", "Print this number of elements in the tensor.");
AddAttr<int>("summarize", "Number of elements printed.");
AddAttr<bool>("print_tensor_name", "Whether to print the tensor name.");
AddAttr<bool>("print_tensor_type", "Whether to print the tensor's dtype.");
AddAttr<bool>("print_tensor_shape", "Whether to print the tensor's shape.");
AddAttr<bool>("print_tensor_lod", "Whether to print the tensor's lod.");
AddAttr<std::string>(
"print_phase",
"(string, default 'BOTH') Which phase to display including 'FORWARD' "
"'BACKWARD' and 'BOTH'.")
.SetDefault(kBoth)
.InEnum({kForward, kBackward, kBoth});
AddOutput("Out", "Output tensor with same data as input tensor.");
AddComment(R"DOC(
Creates a print op that will print when a tensor is accessed.
Creates a print op that will print when a tensor is accessed.
Wraps the tensor passed in so that whenever the tensor is accessed,
the message `message` is printed, along with the current value of the
tensor `t`.)DOC");
Wraps the tensor passed in so that whenever the tensor is accessed,
the message `message` is printed, along with the current value of the
tensor `t`.)DOC");
}
};
class InferShape : public framework::InferShapeBase {
class InferShapeForward : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* context) const override {
PADDLE_ENFORCE(context->HasInput("input"), "input should be set");
PADDLE_ENFORCE(context->HasInput("In"), "Input(In) should not be null.");
context->ShareLoD("In", /*->*/ "Out");
context->SetOutputDim("Out", context->GetInputDim("In"));
}
};
class InferShapeBackward : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* context) const override {
PADDLE_ENFORCE(context->HasInput("In@GRAD"),
"Input(In@GRAD) should not be null.");
context->ShareLoD("In@GRAD", /*->*/ "Out");
context->SetOutputDim("Out", context->GetInputDim("In@GRAD"));
}
};
......@@ -196,11 +257,27 @@ class InferVarType : public framework::VarTypeInference {
framework::BlockDesc* block) const override {}
};
class PrintOpProtoAndCheckGradOpMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op_desc_ptr = new framework::OpDesc();
op_desc_ptr->SetType("print_grad");
op_desc_ptr->SetInput("In@GRAD", OutputGrad("Out"));
op_desc_ptr->SetOutput("Out", InputGrad("In"));
op_desc_ptr->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(print, paddle::operators::TensorPrintOp,
paddle::operators::PrintOpProtoAndCheckMaker,
paddle::operators::InferShape,
paddle::operators::InferVarType,
paddle::framework::EmptyGradOpMaker);
namespace ops = paddle::operators;
REGISTER_OPERATOR(print, ops::TensorPrintOp, ops::PrintOpProtoAndCheckMaker,
ops::PrintOpProtoAndCheckGradOpMaker, ops::InferShapeForward,
ops::InferVarType);
REGISTER_OPERATOR(print_grad, ops::TensorPrintOp, ops::InferShapeBackward);
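With the grad-op maker registered above, the backward print op is generated automatically. A forward print op can also be built programmatically, following the OpRegistry::CreateOp pattern used in while_op below; the variable names and attribute values here are hypothetical:

// Sketch: construct and run a print op by hand. "x"/"x_printed" must exist in
// the scope; all attribute values are illustrative only.
paddle::framework::AttributeMap attrs;
attrs["first_n"] = 10;                  // log at most 10 times
attrs["message"] = std::string("x=");   // printed prefix
attrs["summarize"] = 20;                // elements printed per tensor
attrs["print_tensor_name"] = true;
attrs["print_tensor_type"] = true;
attrs["print_tensor_shape"] = true;
attrs["print_tensor_lod"] = true;
attrs["print_phase"] = std::string("FORWARD");
auto print_op = paddle::framework::OpRegistry::CreateOp(
    "print", {{"In", {"x"}}}, {{"Out", {"x_printed"}}}, attrs);
print_op->Run(scope, place);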
......@@ -26,22 +26,44 @@ class ReorderLoDTensorByRankTableOpProtoMaker
ReorderLoDTensorByRankTableOpProtoMaker(OpProto *proto,
OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) the input lod tensor need to be reordered.");
AddInput("X",
"(LoDTensor), the input lod tensor to be reordered according to "
"Input(RankTable).");
AddInput("RankTable",
"(LoDRankTable) the rank table that input need follow");
AddOutput("Out", "(LoDTensor) reordered lod tensor");
AddComment(R"DOC(ReorderLoDTensorByRankTable
"(LoDRankTable), the rank table according to which Input(X) is "
"reordered.");
AddOutput("Out", "(LoDTensor), the reordered lod tensor.");
AddComment(R"DOC(ReorderLoDTensorByRankTable operator.
Reorder the input X by the rank of `RankTable`. If `RankTable` is ordered by
the indices [3, 0, 2, 1], Input X will reorder its sequences: the third sequence
of X becomes the first sequence of the output.
NOTE: The RankTable does not need to be calculated by X.
Input(X) is a batch of sequences. Input(RankTable) stores new orders of the
input sequence batch. The reorder_lod_tensor_by_rank operator reorders the
Input(X) according to the information provided by Input(RankTable).
For example:
The X = [Seq0, Seq1, Seq2, Seq3]. The indices of RankTable are [3, 0, 2, 1].
The Out = [Seq3, Seq0, Seq2, Seq1] with correct LoD information.
If the indices stored in the Input(RankTable) are [3, 0, 2, 1], the
Input(X) will be reordered so that the fourth sequence in Input(X) becomes the
first one, followed by the original first, third, and second ones.
This is:
X = [Seq0, Seq1, Seq2, Seq3]. The indices in RankTable are [3, 0, 2, 1].
Out = [Seq3, Seq0, Seq2, Seq1] with a new LoD information.
If the LoD information of Input(X) is empty, this means Input(X) is not sequence
data. This is also identical to a batch of sequences where each sequence has a
fixed length 1. In this case, the reorder_lod_tensor_by_rank operator reorders
each slice of Input(X) along the first axis according to Input(RankTable).
This is:
X = [Slice0, Slice1, Slice2, Slice3] and its LoD information is empty. The
indices in RankTable are [3, 0, 2, 1].
Out = [Slice3, Slice0, Slice2, Slice1] with no LoD information appended.
NOTE: This operator sorts Input(X) according to a given LoDRankTable which does
not need to be calculated according to Input(X). It can be calculated according
to another different sequence, and then this operator sorts Input(X) according
to the given LoDRankTable.
)DOC");
}
};
......
......@@ -45,7 +45,7 @@ class ShrinkRNNMemoryOp : public ArrayOp {
rank_items.begin();
auto *out_var = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
PADDLE_ENFORCE(out_var != nullptr, "Output(Out) must be set.");
auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();
size_t height = dst_num_rows;
......@@ -76,15 +76,17 @@ class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor) The step index. The RNN step memory 'X' will be "
"shrinked to match the size of the input of the index'th step.");
AddOutput("Out", "(LoDTensor) The shrinked RNN step memory.");
AddComment(
R"DOC(
In dynamic RNN, we are able to handle sequences of different lengths.
Because of the multiple lengths, the size of each step input can be
different, which may lead to a mismatching between the input of
the current step and the memory generated by the previous one. This
operator shrinks memory according to the size of the next step input,
to make sure that they can match each other.
)DOC");
AddComment(R"DOC(
This operator is used to shrink the output batch of a memory defined in dynamic RNN.
Dynamic RNN is able to handle variable-length sequences, in which sequences in
a mini-batch are sorted by their lengths first. After that, the longest sequence
becomes the first one in the sorted batch, followed by the second longest, the
third longest, and so on. Dynamic RNN then slices a batch input timestep by
timestep from the sorted input. Once any sequence in the input batch reaches its
end, memory defined in dynamic RNN has to shrink its outputs to adapt to the
input batch size for the next time step.
)DOC");
}
};
......@@ -136,6 +138,7 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
math::set_constant(dev_ctx, &rest_tensor, 0.0f);
}
}
dx_tensor.set_lod(x_tensor.lod());
}
};
......
......@@ -121,8 +121,8 @@ class WhileGradOp : public framework::OperatorBase {
for (size_t i = 0; i < outside_og_names.size(); ++i) {
auto outside_og_name = outside_og_names[i];
auto inside_og_name = inside_og_names[i];
VLOG(10) << "Linking outside " << outside_og_name << " --> inside "
<< inside_og_name;
VLOG(8) << "Linking outside " << outside_og_name << " --> inside "
<< inside_og_name;
auto &og_outside =
detail::Ref(scope.FindVar(outside_og_name),
"Cannot find Outside Gradient %s", outside_og_name);
......@@ -141,11 +141,11 @@ class WhileGradOp : public framework::OperatorBase {
auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
auto &inside_array =
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());
VLOG(10) << outside_og_name << " size = " << outside_array.size();
VLOG(8) << outside_og_name << " size = " << outside_array.size();
inside_array.resize(outside_array.size());
for (size_t j = 0; j < inside_array.size(); ++j) {
VLOG(10) << j << " " << outside_array[j].numel();
VLOG(8) << j << " " << outside_array[j].numel();
if (outside_array[j].numel() != 0) {
inside_array[j].set_lod(outside_array[j].lod());
inside_array[j].ShareDataWith(outside_array[j]);
......@@ -187,10 +187,14 @@ class WhileGradOp : public framework::OperatorBase {
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f;
auto var_name = pg_names[param_id];
auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", framework::VariableNameMap{},
{{"Out", {pg_names[param_id]}}}, attrs);
{{"Out", {var_name}}}, attrs);
zero_op->Run(scope, dev_place);
scope.FindVar(var_name)
->GetMutable<framework::LoDTensor>()
->set_lod(inside_tensor.lod());
}
}
......@@ -231,7 +235,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
for (auto &each_ig : igs) {
if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
VLOG(10) << "Ignore " << each_ig;
VLOG(8) << "Ignore " << each_ig;
each_ig = framework::kEmptyVarName;
}
}
......
......@@ -44,7 +44,7 @@ CUDNN_DNN_ROUTINE_EACH_R7(DEFINE_WRAP);
#ifdef PADDLE_USE_DSO
bool HasCUDNN() {
std::call_once(cudnn_dso_flag, GetCudnnDsoHandle, &cudnn_dso_handle);
std::call_once(cudnn_dso_flag, GetCUDNNDsoHandle, &cudnn_dso_handle);
return cudnn_dso_handle != nullptr;
}
......
......@@ -36,7 +36,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
auto operator()(Args... args) -> decltype(__name(args...)) { \
using cudnn_func = decltype(__name(args...)) (*)(Args...); \
std::call_once(cudnn_dso_flag, \
paddle::platform::dynload::GetCudnnDsoHandle, \
paddle::platform::dynload::GetCUDNNDsoHandle, \
&cudnn_dso_handle); \
EnforceCUDNNLoaded(#__name); \
void* p_##__name = dlsym(cudnn_dso_handle, #__name); \
......
......@@ -134,7 +134,7 @@ void GetCublasDsoHandle(void** dso_handle) {
#endif
}
void GetCudnnDsoHandle(void** dso_handle) {
void GetCUDNNDsoHandle(void** dso_handle) {
#if defined(__APPLE__) || defined(__OSX__)
GetDsoHandleFromSearchPath(FLAGS_cudnn_dir, "libcudnn.dylib", dso_handle,
false);
......
......@@ -32,7 +32,7 @@ void GetCublasDsoHandle(void** dso_handle);
* @param **dso_handle dso handler
*
*/
void GetCudnnDsoHandle(void** dso_handle);
void GetCUDNNDsoHandle(void** dso_handle);
/**
* @brief load the DSO of CURAND
......
......@@ -430,13 +430,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def("init_glog", framework::InitGLOG);
m.def("init_devices", &framework::InitDevices);
m.def("use_cpu", framework::UseCPU);
m.def("use_mkldnn", framework::UseMKLDNN);
m.def("use_cuda", framework::UseCUDA);
m.def("use_cudnn", framework::UseCUDNN);
m.def("use_all", framework::UseALL);
m.def("is_compile_gpu", IsCompileGPU);
m.def("set_feed_variable", framework::SetFeedVariable);
m.def("get_fetch_variable", framework::GetFetchVariable);
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#include <string>
#include "paddle/framework/tensor.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/memory/memcpy.h"
#include "paddle/platform/device_context.h"
#include "pybind11/numpy.h"
......@@ -97,14 +97,27 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
template <typename T>
T TensorGetElement(framework::Tensor &self, size_t offset) {
PADDLE_ENFORCE(platform::is_cpu_place(self.place()));
return self.data<T>()[offset];
if (platform::is_cpu_place(self.place())) {
return self.data<T>()[offset];
} else {
std::shared_ptr<framework::Tensor> dst(new framework::Tensor);
framework::Copy(self, platform::CPUPlace(), dst.get());
return dst->data<T>()[offset];
}
}
// TODO(dzhwinter): fix the redundant Tensor allocation and free
template <typename T>
void TensorSetElement(framework::Tensor &self, size_t offset, T elem) {
PADDLE_ENFORCE(platform::is_cpu_place(self.place()));
self.data<T>()[offset] = elem;
if (platform::is_gpu_place(self.place())) {
std::shared_ptr<framework::Tensor> dst(new framework::Tensor);
framework::Copy(self, platform::CPUPlace(), dst.get());
dst->data<T>()[offset] = elem;
framework::Copy(*dst.get(), self.place(), &self);
} else if (platform::is_cpu_place(self.place())) {
self.data<T>()[offset] = elem;
}
}
template <typename T>
......
......@@ -18,14 +18,29 @@ from param_attr import ParamAttr
from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, CUDAPlace
from distribute_transpiler import DistributeTranspiler
from distribute_transpiler_simple import SimpleDistributeTranspiler
import clip
from memory_optimization_transpiler import memory_optimize
Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + [
'io', 'initializer', 'layers', 'nets', 'optimizer', 'backward',
'regularizer', 'LoDTensor', 'CPUPlace', 'CUDAPlace', 'Tensor', 'ParamAttr'
'DataFeeder', 'clip', 'DistributeTranspiler', 'memory_optimize'
'io',
'initializer',
'layers',
'nets',
'optimizer',
'backward',
'regularizer',
'LoDTensor',
'CPUPlace',
'CUDAPlace',
'Tensor',
'ParamAttr',
'DataFeeder',
'clip',
'SimpleDistributeTranspiler',
'DistributeTranspiler',
'memory_optimize',
]
......
......@@ -3,7 +3,10 @@ from . import core
import collections
import copy
__all__ = ['append_backward', 'calc_gradient']
__all__ = [
'append_backward',
'calc_gradient',
]
def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
......
......@@ -3,7 +3,10 @@ import layers
from . import core
__all__ = [
'GradientClipByValue', 'append_gradient_clip_ops', 'error_clip_callback'
'GradientClipByValue',
'ErrorClipByValue',
'append_gradient_clip_ops',
'error_clip_callback',
]
......@@ -23,12 +26,12 @@ class ErrorClipByValue(BaseErrorClipAttr):
self.min = min
def append_clip_op(self, block, grad_name):
block.append_op(
type="clip",
inputs={"X": grad_name},
outputs={"Out": grad_name},
attrs={"min": self.min,
"max": self.max})
clip_op_desc = block.desc.append_op()
clip_op_desc.set_type("clip")
clip_op_desc.set_input("X", [grad_name])
clip_op_desc.set_output("Out", [grad_name])
clip_op_desc.set_attr("min", self.min)
clip_op_desc.set_attr("max", self.max)
def error_clip_callback(block, context):
......@@ -39,6 +42,11 @@ def error_clip_callback(block, context):
op_desc.output_arg_names()):
fwd_var = block.var_recursive(grad_to_var[grad_n])
error_clip = getattr(fwd_var, "error_clip", None)
if not (error_clip is None or isinstance(error_clip,
BaseErrorClipAttr)):
raise TypeError(
"Variable's error_clip should be an instance of BaseErrorClipAttr or None."
)
if error_clip is not None:
error_clip.append_clip_op(block, grad_n)
......
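As a usage sketch (mirroring the test added later in this commit; `prog`, `hidden`, and `avg_cost` are assumed to come from an existing network definition), a variable's error clip is attached via `set_error_clip` and picked up by `error_clip_callback` when the backward pass is built:

.. code-block:: python

    import paddle.v2.fluid as fluid

    # prog, hidden and avg_cost are assumed to exist already.
    prog.block(0).var(hidden.name).set_error_clip(
        fluid.clip.ErrorClipByValue(max=2e-6, min=-1e-6))
    fluid.backward.append_backward(
        loss=avg_cost, callback=fluid.clip.error_clip_callback)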
"""
Default scope function.
`Paddle` manages Scope as programming language's scope. It just a
thread-local stack of Scope. Top of that stack is current scope, the bottom
of that stack is all scopes' parent.
`Paddle` manages Scope as programming language's scope. It just a
thread-local stack of Scope. Top of that stack is current scope, the bottom
of that stack is all scopes' parent.
Invoking `var/find_var` can `new/find` variable in current scope.
Invoking `enter_local_scope/leave_local_scope` can create or destroy local
scope.
Invoking `var/find_var` can `new/find` variable in current scope.
Invoking `enter_local_scope/leave_local_scope` can create or destroy local
scope.
A `scoped_function` will take a `function` as input. That function will be
invoked in a new local scope.
A `scoped_function` will take a `function` as input. That function will be
invoked in a new local scope.
"""
import paddle.v2.fluid.core
......@@ -19,8 +19,12 @@ import threading
__tl_scope__ = threading.local()
__all__ = [
'get_cur_scope', 'enter_local_scope', 'leave_local_scope', 'var',
'find_var', 'scoped_function'
'get_cur_scope',
'enter_local_scope',
'leave_local_scope',
'var',
'find_var',
'scoped_function',
]
......@@ -71,7 +75,7 @@ def find_var(name):
def scoped_function(func):
"""
invoke `func` in new scope.
:param func: a callable function that will be run in new scope.
:type func: callable
"""
......
import framework
from framework import Program, default_main_program, Parameter, Variable
import optimizer
from layer_helper import LayerHelper
def hash_name_to_server(params_grads, pserver_endpoints):
"""
:param param_grads:
:return: a map of pserver endpoint ->
params -> [param list]
grads -> [grad list]
"""
def _hash_param(param_name, total):
return hash(param_name) % total
param_grad_map = dict()
for param, grad in params_grads:
if param.trainable is True and grad is not None:
server_id = _hash_param(param.name, len(pserver_endpoints))
server_for_param = pserver_endpoints[server_id]
if server_for_param not in param_grad_map:
param_grad_map[server_for_param] = {"params": [], "grads": []}
param_grad_map[server_for_param]["params"].append(param)
param_grad_map[server_for_param]["grads"].append(grad)
return param_grad_map
def round_robin(params_grads, pserver_endpoints):
assert (len(params_grads) > len(pserver_endpoints))
param_grad_map = dict()
pserver_idx = 0
for param, grad in params_grads:
if param.trainable is True:
server_for_param = pserver_endpoints[pserver_idx]
if server_for_param not in param_grad_map:
param_grad_map[server_for_param] = {"params": [], "grads": []}
param_grad_map[server_for_param]["params"].append(param)
param_grad_map[server_for_param]["grads"].append(grad)
pserver_idx += 1
if pserver_idx >= len(pserver_endpoints):
pserver_idx = 0
return param_grad_map
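A minimal sketch of how `round_robin` distributes parameters (the `_FakeParam` stub is hypothetical; real callers pass fluid `Parameter` objects paired with their gradients):

.. code-block:: python

    class _FakeParam(object):
        def __init__(self, name):
            self.name = name
            self.trainable = True

    params_grads = [(_FakeParam("w1"), "g_w1"), (_FakeParam("w2"), "g_w2"),
                    (_FakeParam("b1"), "g_b1")]
    endpoints = ["127.0.0.1:6174", "127.0.0.1:6175"]
    # w1 and b1 land on the first pserver, w2 on the second.
    print(round_robin(params_grads, endpoints))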
class SimpleDistributeTranspiler:
def transpile(self,
optimize_ops,
params_grads,
program=None,
pservers="127.0.0.1:6174",
trainers=1,
split_method=round_robin):
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
to do parameter optimization, and the optimization graph will be put
into a parameter server program.
Use different methods to split trainable variables across different
parameter servers.
Example to run:
exe = fluid.Executor(place)
t = fluid.DistributeTranspiler()
t.transpile(optimize_ops, params_grads, pservers="127.0.0.1:6174", trainers=1)
pserver_endpoint = os.getenv("PSERVER")
if pserver_endpoint:
pserver_prog = t.get_pserver_program(pserver_endpoint, optimize_ops)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
else:
feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
exe.run(fluid.default_startup_program())
for pass_id in range(PASS_NUM):
...
:param optimize_ops: op list of optimization, should be the
return value of Optimizer.minimize
:type optimize_ops: list
:param program: program to optimize; defaults to default_main_program
:param pservers: parameter server endpoints like "m1:6174,m2:6174"
:type pservers: string
:return: a list of programs
"""
if program is None:
program = default_main_program()
self.program = program
self.trainers = trainers
self.optimize_ops = optimize_ops
self._optimize_distributed(
optimize_ops,
program,
params_grads,
pservers=pservers,
trainers=trainers,
split_method=split_method)
def _clone_param(self, block, v):
assert isinstance(v, Parameter)
new_p = Parameter(
block=block,
shape=v.shape,
dtype=v.dtype,
type=v.type,
lod_level=v.lod_level,
stop_gradient=v.stop_gradient,
trainable=v.trainable,
optimize_attr=v.optimize_attr,
regularizer=v.regularizer,
name=v.name)
block.vars[new_p.name] = new_p
def _clone_var(self, block, var):
assert isinstance(var, Variable)
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=var.persistable)
def _optimize_distributed(self, optimize_ops, program, params_and_grads,
**kwargs):
if kwargs.has_key("split_method"):
split_method = kwargs["split_method"]
else:
split_method = round_robin
assert (callable(split_method))
pserver_endpoints = kwargs["pservers"].split(",")
self.param_grad_map = split_method(params_and_grads, pserver_endpoints)
send_op_ordered_inputs = []
send_op_ordered_outputs = []
epmap = []
for ep, v in self.param_grad_map.iteritems():
send_op_ordered_inputs.extend(v["grads"])
send_op_ordered_outputs.extend(v["params"])
for i in v["grads"]:
epmap.append(ep)
send_op = program.global_block().append_op(
type="send",
inputs={"X": send_op_ordered_inputs
}, # inputs is a list of tensors to be sent
outputs={"Out": send_op_ordered_outputs},
attrs={"endpoints": pserver_endpoints,
"epmap": epmap})
def get_trainer_program(self):
# remove optimize ops and add a send op to main_program
self.program.global_block().delete_ops(self.optimize_ops)
return self.program
def _create_var_for_trainers(self, block, var, trainers):
var_list = []
for i in xrange(trainers):
var_each = block.create_var(
name="%s.trainer_%d" % (var.name, i),
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
var_list.append(var_each)
return var_list
def get_pserver_program(self, endpoint, optimize_ops):
pserver_program = Program()
for v in self.param_grad_map[endpoint]["params"]:
self._clone_param(pserver_program.global_block(), v)
optimize_sub_program = Program()
grad_var_names = [
var.name for var in self.param_grad_map[endpoint]["grads"]
]
for opt_op in optimize_ops:
for _, var in opt_op.inputs.iteritems():
# NOTE: append operators to merge gradients from multiple
# trainers. If trainers == 1, this is not needed.
if self.trainers > 1 and var.name in grad_var_names:
vars2merge = self._create_var_for_trainers(
optimize_sub_program.global_block(), var, self.trainers)
merged_var = optimize_sub_program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
optimize_sub_program.global_block().append_op(
type="sum",
inputs={"X": vars2merge},
outputs={"Out": merged_var})
optimize_sub_program.global_block().append_op(
type="scale",
inputs={"X": merged_var},
outputs={"Out": merged_var},
attrs={"scale": 1.0 / float(self.trainers)})
else:
optimize_sub_program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
if opt_op.inputs.has_key("Grad"):
if opt_op.inputs["Grad"].name in grad_var_names:
optimize_sub_program.global_block().append_op(
type=opt_op.type,
inputs=opt_op.inputs,
outputs=opt_op.outputs,
attrs=opt_op.attrs)
else:
optimize_sub_program.global_block().append_op(
type=opt_op.type,
inputs=opt_op.inputs,
outputs=opt_op.outputs,
attrs=opt_op.attrs)
pserver_program.global_block().append_op(
type="recv",
inputs={"RX":
self.param_grad_map[endpoint]["grads"]}, # grads to recv
outputs={},
attrs={
"OptimizeProgram": optimize_sub_program.desc,
"endpoint": endpoint,
"ParamList":
[p.name for p in self.param_grad_map[endpoint]["params"]],
"GradList":
[p.name for p in self.param_grad_map[endpoint]["grads"]],
"Trainers": self.trainers
})
pserver_program.sync_with_cpp()
return pserver_program
def hash_name(varlist, pserver_endpoints):
"""
hash variable names to several endpoints.
:param varlist: a list of Variables
:return: a map of pserver endpoint -> varname
"""
def _hash_block(block_str, total):
return hash(block_str) % total
eplist = []
for var in varlist:
server_id = _hash_block(var.name(), len(pserver_endpoints))
server_for_param = pserver_endpoints[server_id]
eplist.append(server_for_param)
return eplist
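`hash_name` can be exercised the same way (the `_FakeVar` stub below is hypothetical; note that the function calls `var.name()` as a method):

.. code-block:: python

    class _FakeVar(object):
        def __init__(self, name):
            self._name = name
        def name(self):
            return self._name

    eps = ["127.0.0.1:6174", "127.0.0.1:6175"]
    # Each variable is assigned the endpoint hash(var.name()) % len(eps).
    print(hash_name([_FakeVar("w1"), _FakeVar("w2")], eps))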
def round_robin(varlist, pserver_endpoints):
"""
Distribute variables to several endpoints in round-robin order.
"""
assert (len(varlist) > len(pserver_endpoints))
eplist = []
pserver_idx = 0
for var in varlist:
server_for_param = pserver_endpoints[pserver_idx]
eplist.append(server_for_param)
pserver_idx += 1
if pserver_idx >= len(pserver_endpoints):
pserver_idx = 0
return eplist
......@@ -4,7 +4,10 @@ import layers
from framework import Program, unique_name, Variable, program_guard
from layer_helper import LayerHelper
__all__ = ['Accuracy', 'ChunkEvaluator']
__all__ = [
'Accuracy',
'ChunkEvaluator',
]
def _clone_var_(block, var):
......@@ -21,19 +24,19 @@ def _clone_var_(block, var):
class Evaluator(object):
"""
Base Class for all evaluators
Args:
name(str): The name of the evaluator, such as "accuracy". Used to generate
a temporary variable name.
main_program(Program, optional): The evaluator should be added to this
main_program. Default default_main_program()
startup_program(Program, optional): The parameter should be added to this
startup_program. Default default_startup_program()
Attributes:
states(list): The list of state variables. states will be reset to zero
when `reset` is invoked.
metrics(list): The list of metrics variables. They will be calculated
every mini-batch.
"""
......@@ -66,14 +69,14 @@ class Evaluator(object):
def create_state(self, suffix, dtype, shape):
"""
Create a state variable.
NOTE: It is not a public API.
Args:
suffix(str): the state suffix.
dtype(str|core.DataType): the state data type.
shape(tuple|list): the shape of the state.
Returns: State variable
......@@ -127,8 +130,8 @@ class Accuracy(Evaluator):
class ChunkEvaluator(Evaluator):
"""
Accumulate counter numbers output by chunk_eval from mini-batches and
compute the precision, recall and F1-score using the accumulated counter
numbers.
"""
......
......@@ -7,9 +7,15 @@ import proto.framework_pb2 as framework_pb2
from . import core
__all__ = [
'Block', 'Variable', 'Program', 'Operator', 'default_startup_program',
'default_main_program', 'program_guard', 'switch_startup_program',
'switch_main_program'
'Block',
'Variable',
'Program',
'Operator',
'default_startup_program',
'default_main_program',
'program_guard',
'switch_startup_program',
'switch_main_program',
]
EMPTY_VAR_NAME = core.kEmptyVarName()
......@@ -274,6 +280,9 @@ class Variable(object):
uid = core.unique_integer(prefix) # unique during whole process.
return "_".join([prefix, str(uid)])
def set_error_clip(self, error_clip):
self.error_clip = error_clip
def get_all_op_protos():
"""
......
import framework
import numpy as np
__all__ = ['Constant', 'Uniform', 'Normal', 'Xavier']
__all__ = [
'Constant',
'Uniform',
'Normal',
'Xavier',
]
class Initializer(object):
......
......@@ -4,9 +4,15 @@ import cPickle as pickle
from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
__all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', "save_inference_model", "load_inference_model",
"get_inference_program"
'save_vars',
'save_params',
'save_persistables',
'load_vars',
'load_params',
'load_persistables',
'save_inference_model',
'load_inference_model',
'get_inference_program',
]
......
......@@ -117,7 +117,8 @@ def Print(input,
print_tensor_name=True,
print_tensor_type=True,
print_tensor_shape=True,
print_tensor_lod=True):
print_tensor_lod=True,
print_phase='both'):
'''
**Print operator**
......@@ -128,18 +129,21 @@ def Print(input,
tensor `t`.
Args:
input(Variable): A Tensor to print.
summarize(int): Print this number of elements in the tensor, will print all
if left negative.
message(str): A string message to print as a prefix.
first_n(int): Only log `first_n` number of times.
print_tensor_name(bool): Print the tensor name.
print_tensor_type(bool): Print the tensor type.
print_tensor_shape(bool): Print the tensor shape.
print_tensor_lod(bool): Print the tensor lod.
input (Variable): A Tensor to print.
summarize (int): Print this number of elements in the tensor; print
all elements if the value is negative.
message (str): A string message to print as a prefix.
first_n (int): Only log `first_n` number of times.
print_tensor_name (bool): Print the tensor name.
print_tensor_type (bool): Print the tensor type.
print_tensor_shape (bool): Print the tensor shape.
print_tensor_lod (bool): Print the tensor lod.
print_phase (str): Which phase to display, including 'forward',
'backward' and 'both'. If set to 'backward' or 'both', the
gradients of the input tensor will be printed.
Returns:
None
Variable: Output tensor, with the same data as the input tensor.
Examples:
.. code-block:: python
......@@ -149,10 +153,10 @@ def Print(input,
message="The content of some_layer: ")
'''
helper = LayerHelper('print', **locals())
out = helper.create_tmp_variable(dtype='int32')
out = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(
type='print',
inputs={'input': input},
inputs={'In': input},
attrs={
'first_n': first_n,
'summarize': summarize,
......@@ -161,7 +165,9 @@ def Print(input,
'print_tensor_type': print_tensor_type,
'print_tensor_shape': print_tensor_shape,
'print_tensor_lod': print_tensor_lod,
})
'print_phase': print_phase.upper()
},
outputs={'Out': out})
return out
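A short sketch of the updated layer (following the rewritten test near the end of this commit; the network variables are illustrative):

.. code-block:: python

    import paddle.v2.fluid as fluid
    from paddle.v2.fluid.backward import append_backward

    x = fluid.layers.data('x', shape=[3], dtype='float32', lod_level=1)
    x.stop_gradient = False
    # Print gradients as well as activations during the backward pass.
    printed = fluid.layers.Print(input=x, message="debug: ",
                                 print_phase='backward')
    loss = fluid.layers.mean(x=printed)
    append_backward(loss=loss)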
......@@ -742,11 +748,10 @@ def topk(input, k):
def lod_tensor_to_array(x, table):
"""This function performs the operation that converts an LOD_Tensor to
an array.
""" Convert a LOD_TENSOR to an LOD_TENSOR_ARRAY.
Args:
x (Variable|list): The tensor that needs to be converted to an array.
x (Variable|list): The LOD tensor to be converted to a LOD tensor array.
table (ParamAttr|list): The variable that stores the level of lod
which is ordered by sequence length in
descending order.
......@@ -776,11 +781,10 @@ def lod_tensor_to_array(x, table):
def array_to_lod_tensor(x, table):
"""This function performs the operations that converts an array to
an LOD_Tensor.
"""Convert a LoD_Tensor_Aarry to an LoDTensor.
Args:
x (Variable|list): The array that needs to be converted to a tensor.
x (Variable|list): The lod tensor array to be converted to a tensor.
table (ParamAttr|list): The variable that stores the level of lod
which is ordered by sequence length in
descending order.
......@@ -808,7 +812,8 @@ def array_to_lod_tensor(x, table):
def increment(x, value=1.0, in_place=True):
"""This function performs an operation that increments each value in the
"""
This function performs an operation that increments each value in the
input :math:`x` by the amount given in :math:`value`. This operation is
performed in-place by default.
......@@ -841,17 +846,24 @@ def increment(x, value=1.0, in_place=True):
def array_write(x, i, array=None):
"""This function performs the operation to write the data out as an
LOD_TENSOR_ARRAY.
"""
This function writes the given input variable to the position specified
by the array index in an output LOD_TENSOR_ARRAY. If the output
LOD_TENSOR_ARRAY is not given (None), a new one will be created and
returned.
Args:
x (Variable|list): The input tensor from which the data will be read.
i (Variable|list): The subscript index in tensor array, that points the
place from which data will be read.
array (Variable|list): The data can be read into this variable if
this is assigned.
i (Variable|list): The index of the output LOD_TENSOR_ARRAY, pointing to
the position to which the input tensor will be
written.
array (Variable|list): The output LOD_TENSOR_ARRAY to which the input
tensor will be written. If this parameter is
NONE, a new LOD_TENSOR_ARRAY will be created and
returned.
Returns:
Variable: The tensor type variable that has the data written to it.
Variable: The output LOD_TENSOR_ARRAY where the input tensor is written.
Examples:
.. code-block:: python
......@@ -1214,7 +1226,8 @@ class DynamicRNN(object):
self.lod_rank_table = None
self.max_seq_len = None
self.step_idx = None
self.zero_idx = fill_constant(shape=[1], value=0, dtype='int64')
self.zero_idx = fill_constant(
shape=[1], value=0, dtype='int64', force_cpu=True)
self.mem_dict = dict()
self.output_array = []
self.outputs = []
......@@ -1228,7 +1241,7 @@ class DynamicRNN(object):
self._assert_in_rnn_block_("step_input")
if not isinstance(x, Variable):
raise TypeError(
"step_input() can only take a Variable as its input")
"step_input() can only take a Variable as its input.")
parent_block = self._parent_block_()
if self.lod_rank_table is None:
self.lod_rank_table = parent_block.create_var(
......@@ -1269,7 +1282,8 @@ class DynamicRNN(object):
def block(self):
if self.status != DynamicRNN.BEFORE_RNN:
raise ValueError("rnn.block() can only be invoke once")
self.step_idx = fill_constant(shape=[1], dtype='int64', value=0)
self.step_idx = fill_constant(
shape=[1], dtype='int64', value=0, force_cpu=True)
self.step_idx.stop_gradient = False
self.status = DynamicRNN.IN_RNN
with self.while_op.block():
......@@ -1289,8 +1303,8 @@ class DynamicRNN(object):
def __call__(self, *args, **kwargs):
if self.status != DynamicRNN.AFTER_RNN:
raise ValueError(
"Dynamic RNN outputs can only be retrieved after rnn block")
raise ValueError(("Output of the dynamic RNN can only be visited "
"outside the rnn block."))
if len(self.outputs) == 1:
return self.outputs[0]
else:
......
......@@ -9,12 +9,33 @@ from ..param_attr import ParamAttr
from tensor import concat
__all__ = [
'fc', 'embedding', 'dynamic_lstm', 'gru_unit', 'linear_chain_crf',
'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy',
'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d',
'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min',
'sequence_first_step', 'sequence_last_step', 'dropout'
'fc',
'embedding',
'dynamic_lstm',
'gru_unit',
'linear_chain_crf',
'crf_decoding',
'cos_sim',
'cross_entropy',
'square_error_cost',
'accuracy',
'chunk_eval',
'sequence_conv',
'conv2d',
'sequence_pool',
'pool2d',
'batch_norm',
'beam_search_decode',
'conv2d_transpose',
'sequence_expand',
'lstm_unit',
'reduce_sum',
'reduce_mean',
'reduce_max',
'reduce_min',
'sequence_first_step',
'sequence_last_step',
'dropout',
]
......@@ -248,13 +269,13 @@ def gru_unit(input,
h_t & = dot((1-u_t), m_t) + dot(u_t, h_{t-1})
The inputs of gru unit include :math:`z_t` and :math:`h_{t-1}`. In terms
of the equation above, the :math:`z_t` is split into 3 parts -
:math:`xu_t`, :math:`xr_t` and :math:`xm_t`. This means that in order to
implement a full GRU unit operator for an input, a fully
connected layer has to be applied, such that :math:`z_t = W_{fc}x_t`.
The terms :math:`u_t` and :math:`r_t` represent the update and reset gates
of the GRU cell. Unlike LSTM, GRU has one fewer gate. However, there is
an intermediate candidate hidden output, which is denoted by :math:`m_t`.
This layer has three outputs :math:`h_t`, :math:`dot(r_t, h_{t-1})`
and concatenation of :math:`u_t`, :math:`r_t` and :math:`m_t`.
......@@ -276,7 +297,7 @@ def gru_unit(input,
.. code-block:: python
# assuming we have x_t_data and prev_hidden of size=10
x_t = fluid.layers.fc(input=x_t_data, size=30)
hidden_val, r_h_val, gate_val = fluid.layers.gru_unit(input=x_t,
hidden = prev_hidden)
......@@ -754,7 +775,7 @@ def conv2d(input,
pre_bias = helper.create_tmp_variable(dtype)
helper.append_op(
type='conv2d_cudnn',
type='conv2d',
inputs={
'Input': input,
'Filter': filter_param,
......
from ..registry import register_layer
__activations__ = [
'abs', 'tanh', 'sigmoid', 'relu', 'sqrt', 'ceil', 'floor', 'log', 'round'
'sigmoid',
'logsigmoid',
'exp',
'relu',
'tanh',
'tanh_shrink',
'softshrink',
'sqrt',
'abs',
'ceil',
'floor',
'round',
'reciprocal',
'log',
'square',
'softplus',
'softsign',
'brelu',
'leaky_relu',
'soft_relu',
'elu',
'relu6',
'pow',
'stanh',
'hard_shrink',
'thresholded_relu',
'hard_sigmoid',
'swish',
]
__all__ = [
......
......@@ -6,8 +6,16 @@ from ..core import DataType
import numpy
__all__ = [
'create_tensor', 'create_parameter', 'cast', 'concat', 'sums', 'assign',
'fill_constant_batch_size_like', 'fill_constant', 'ones', 'zeros'
'create_tensor',
'create_parameter',
'cast',
'concat',
'sums',
'assign',
'fill_constant_batch_size_like',
'fill_constant',
'ones',
'zeros',
]
......@@ -172,29 +180,30 @@ def assign(input, output):
return output
def fill_constant(shape, dtype, value, out=None):
def fill_constant(shape, dtype, value, force_cpu=False, out=None):
"""
**fill_constant**
This function creates a tensor of specified *shape* and
*dtype*, and initializes this with a constant supplied in *value*.
This function creates a tensor with specified `shape` and `dtype`, and
initializes it with a constant specified by `value`.
It also sets *stop_gradient* to True.
The attribute `stop_gradient` of the created tensor is set to True.
Args:
shape(tuple|list|None): Shape of output tensor
dtype(np.dtype|core.DataType|str): Data type of output tensor
value(float): Constant value to initialize the output tensor
out(Variable): Output Variable to initialize
shape(tuple|list|None): Shape of the output tensor.
dtype(np.dtype|core.DataType|str): Data type of the output tensor.
value(float): The constant value used to initialize the output tensor.
force_cpu(bool): Whether to force the output tensor to reside in CPU memory.
out(Variable): The output tensor.
Returns:
Variable: The tensor variable storing the output
Variable: The tensor variable storing the output.
Examples:
.. code-block:: python
data = fluid.layers.fill_constant(shape=[1], value=0, dtype='int64')
"""
helper = LayerHelper("fill_constant", **locals())
if out is None:
out = helper.create_tmp_variable(dtype=dtype)
......@@ -202,9 +211,12 @@ def fill_constant(shape, dtype, value, out=None):
type='fill_constant',
inputs={},
outputs={'Out': [out]},
attrs={'shape': shape,
'dtype': out.dtype,
'value': float(value)})
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu
})
out.stop_gradient = True
return out
......
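A small sketch of the new `force_cpu` flag (mirroring how the DynamicRNN changes in this commit pin `step_idx` and `zero_idx` to host memory):

.. code-block:: python

    import paddle.v2.fluid as fluid

    # Keep the loop counter in CPU memory even when the model runs on GPU.
    step_idx = fluid.layers.fill_constant(
        shape=[1], dtype='int64', value=0, force_cpu=True)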
......@@ -121,8 +121,10 @@ class ControlFlowGraph(object):
# and dtype_to_size[cache_dtype]
if x_dtype == cache_dtype:
print(
"Hit Cache !!!! cache pool index is %d, var name is %s, cached var name is %s, var shape is %s "
%
("Hit Cache !!!! cache pool index "
"is %d, var name is %s, "
"cached var name is %s, "
"var shape is %s ") %
(index, x, cache_var, str(cache_shape)))
self.pool.pop(index)
_rename_arg_(
......
import layers
__all__ = ["simple_img_conv_pool", "sequence_conv_pool"]
__all__ = [
"simple_img_conv_pool",
"sequence_conv_pool",
]
def simple_img_conv_pool(input,
......
......@@ -8,7 +8,11 @@ import proto.framework_pb2 as framework_pb2
from framework import OpProtoHolder, Variable, Program, Operator
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
__all__ = ['deprecated', 'register_layer', 'autodoc']
__all__ = [
'deprecated',
'register_layer',
'autodoc',
]
def _convert_(name):
......@@ -80,11 +84,10 @@ def _generate_doc_string_(op_proto):
def register_layer(op_type):
"""
Register an Python layer for an Operator
"""Register the Python layer for an Operator.
Args:
op_type: The name of the operator to be created
op_type: The name of the operator to be created.
This function takes in the operator type (sigmoid, mean, average, etc.) and
creates the operator functionality.
......@@ -98,16 +101,16 @@ def register_layer(op_type):
if len(not_intermediate_outputs) != 1:
raise ValueError("Only one non intermediate output operator can be",
"automatically generated")
"automatically generated.")
if not_intermediate_outputs[0].duplicable:
raise ValueError(
"Only non duplicable op can be automatically generated")
"Only non duplicable op can be automatically generated.")
for output in intermediate_outputs:
if output.duplicable:
raise ValueError("The op can be automatically generated only when ",
"all intermediate ops are not duplicable")
"all intermediate ops are not duplicable.")
o_name = not_intermediate_outputs[0].name
intermediate_output_names = [output.name for output in intermediate_outputs]
......
import framework
__all__ = ['append_regularization_ops', 'L1Decay', 'L2Decay']
__all__ = [
'append_regularization_ops',
'L1Decay',
'L2Decay',
]
def append_regularization_ops(parameters_and_grads, regularization=None):
......
......@@ -5,3 +5,4 @@ foreach(src ${TEST_OPS})
endforeach()
add_subdirectory(book)
add_subdirectory(book_distribute)
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
import math
import unittest
from paddle.v2.fluid.distribute_transpiler import split_dense_variable
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
import random
class TestSplitVar(unittest.TestCase):
def test_check_output(self):
# split below shapes to 10 servers
shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10]]
expected_sizes = [
[15], [1024],
[2352, 2352, 2352, 2352, 2352, 2352, 2352, 2352, 2352, 784],
[2040, 2040, 2040, 2040],
[1150, 1150, 1150, 1150, 1150, 1150, 1100]
]
var_list = []
program = fluid.Program()
for shape in shapes:
var = program.global_block().create_var(
name=str(random.randint(10000, 99999)),
persistable=True,
# dtype=core.VarDesc.VarType.LOD_TENSOR,
shape=shape)
var_list.append(var)
blocks = split_dense_variable(var_list, 10)
all_sizes = []
for s in expected_sizes:
for s2 in s:
all_sizes.append(s2)
for i, block_str in enumerate(blocks):
varname, block_id, size = block_str.split(":")
self.assertEqual(int(size), all_sizes[i])
if __name__ == '__main__':
unittest.main()
......@@ -31,7 +31,8 @@ def create_op(scope, op_type, inputs, outputs, attrs):
kwargs[in_name] = []
if in_dup:
sub_in = inputs[in_name]
for sub_in_name, _ in sub_in:
for item in sub_in:
sub_in_name, _ = item[0], item[1]
__create_var__(in_name, sub_in_name)
else:
__create_var__(in_name, in_name)
......@@ -41,7 +42,8 @@ def create_op(scope, op_type, inputs, outputs, attrs):
kwargs[out_name] = []
if out_dup:
sub_out = outputs[out_name]
for sub_out_name, _ in sub_out:
for item in sub_out:
sub_out_name, _ = item[0], item[1]
__create_var__(out_name, sub_out_name)
else:
__create_var__(out_name, out_name)
......@@ -71,13 +73,15 @@ def set_input(scope, op, inputs, place):
if in_name in inputs:
if in_dup:
sub_in = inputs[in_name]
for sub_in_name, sub_in_val in sub_in:
for item in sub_in:
sub_in_name, sub_in_val = item[0], item[1]
__set_input__(sub_in_name, sub_in_val)
else:
__set_input__(in_name, inputs[in_name])
def get_numeric_gradient(scope,
def get_numeric_gradient(place,
scope,
op,
inputs,
input_to_check,
......@@ -85,7 +89,7 @@ def get_numeric_gradient(scope,
delta=0.005,
in_place=False):
# FIXME: change this method by compile time concepts
set_input(scope, op, inputs, core.CPUPlace())
set_input(scope, op, inputs, place)
def product(dim):
return reduce(lambda a, b: a * b, dim, 1)
......@@ -93,7 +97,7 @@ def get_numeric_gradient(scope,
def get_output():
sum = []
for output_name in output_names:
op.run(scope, core.CPUPlace())
op.run(scope, place)
sum.append(
np.array(scope.find_var(output_name).get_tensor()).mean())
return np.array(sum).mean()
......@@ -127,7 +131,7 @@ def get_numeric_gradient(scope,
# we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size):
if in_place:
set_input(scope, op, inputs, core.CPUPlace())
set_input(scope, op, inputs, place)
# get one input element throw it's index i.
origin = __get_elem__(tensor_to_check, i)
......@@ -137,7 +141,7 @@ def get_numeric_gradient(scope,
y_pos = get_output()
if in_place:
set_input(scope, op, inputs, core.CPUPlace())
set_input(scope, op, inputs, place)
x_neg = origin - delta
__set_elem__(tensor_to_check, i, x_neg)
......@@ -283,7 +287,8 @@ class OpTest(unittest.TestCase):
if not isinstance(sub_out, list):
raise AssertionError("sub_out type %s is not list",
type(sub_out))
for sub_out_name, expect in sub_out:
for item in sub_out:
sub_out_name, expect = item[0], item[1]
idx = find_actual(sub_out_name, fetch_list)
actual = outs[idx]
actual_t = np.array(actual)
......@@ -347,6 +352,24 @@ class OpTest(unittest.TestCase):
in_place=False,
max_relative_error=0.005,
user_defined_grads=None):
places = [core.CPUPlace()]
if core.is_compile_gpu() and core.op_support_gpu(self.op_type):
places.append(core.CUDAPlace(0))
for place in places:
self.check_grad_with_place(place, inputs_to_check, output_names,
no_grad_set, numeric_grad_delta,
in_place, max_relative_error,
user_defined_grads)
def check_grad_with_place(self,
place,
inputs_to_check,
output_names,
no_grad_set=None,
numeric_grad_delta=0.005,
in_place=False,
max_relative_error=0.005,
user_defined_grads=None):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
......@@ -362,6 +385,7 @@ class OpTest(unittest.TestCase):
numeric_grads = user_defined_grads or [
get_numeric_gradient(
place,
self.scope,
self.op,
self.inputs,
......@@ -370,22 +394,12 @@ class OpTest(unittest.TestCase):
delta=numeric_grad_delta,
in_place=in_place) for input_to_check in inputs_to_check
]
cpu_place = core.CPUPlace()
cpu_analytic_grads = self._get_gradient(inputs_to_check, cpu_place,
output_names, no_grad_set)
self.__assert_is_close(numeric_grads, cpu_analytic_grads,
inputs_to_check, max_relative_error,
"Gradient Check On %s" % str(cpu_place))
if core.is_compile_gpu() and self.op.support_gpu():
gpu_place = core.CUDAPlace(0)
gpu_analytic_grads = self._get_gradient(inputs_to_check, gpu_place,
output_names, no_grad_set)
self.__assert_is_close(numeric_grads, gpu_analytic_grads,
inputs_to_check, max_relative_error,
"Gradient Check On %s" % str(gpu_place))
analytic_grads = self._get_gradient(inputs_to_check, place,
output_names, no_grad_set)
self.__assert_is_close(numeric_grads, analytic_grads, inputs_to_check,
max_relative_error,
"Gradient Check On %s" % str(place))
@staticmethod
def _create_var_descs_(block, var_dict):
......
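The refactor above replaces the hard-coded CPU/GPU branches with a single place-parameterized path; its dispatch logic, in miniature (illustrative only):

.. code-block:: python

    import paddle.v2.fluid.core as core

    def places_to_check(op_type):
        # CPU is always checked; GPU only when both the build and the op support it.
        places = [core.CPUPlace()]
        if core.is_compile_gpu() and core.op_support_gpu(op_type):
            places.append(core.CUDAPlace(0))
        return places

    for place in places_to_check("conv2d"):
        print(place)  # each place is then passed to check_grad_with_place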
from __future__ import print_function
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
BATCH_SIZE = 128
CLIP_MAX = 2e-6
CLIP_MIN = -1e-6
prog = fluid.framework.Program()
with fluid.program_guard(main_program=prog):
image = fluid.layers.data(name='x', shape=[784], dtype='float32')
hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
prog_clip = prog.clone()
prog_clip.block(0).var(hidden1.name).set_error_clip(
fluid.clip.ErrorClipByValue(
max=CLIP_MAX, min=CLIP_MIN))
avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
fluid.backward.append_backward(loss=avg_cost)
fluid.backward.append_backward(
loss=avg_cost_clip, callback=fluid.clip.error_clip_callback)
hidden1_grad = prog.block(0).var(hidden1.name + "@GRAD")
hidden1_grad_clip = prog_clip.block(0).var(hidden1.name + "@GRAD")
hidden2_grad = prog.block(0).var(hidden2.name + "@GRAD")
hidden2_grad_clip = prog_clip.block(0).var(hidden2.name + "@GRAD")
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=BATCH_SIZE)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
exe.run(fluid.default_startup_program())
count = 0
for data in train_reader():
count += 1
if count > 5:
break
out1, out2 = exe.run(prog,
feed=feeder.feed(data),
fetch_list=[hidden1_grad, hidden2_grad])
out1_clip, out2_clip = exe.run(
prog_clip,
feed=feeder.feed(data),
fetch_list=[hidden1_grad_clip, hidden2_grad_clip])
if not ((out1.clip(
min=CLIP_MIN, max=CLIP_MAX) == out1_clip).all() and
(out2 == out2_clip).all()):
exit(1)
exit(0)
......@@ -49,7 +49,7 @@ def conv2d_forward_naive(input, filter, group, conv_param):
class TestConv2dOp(OpTest):
def setUp(self):
core.use_cuda()
self.use_cudnn = False
self.init_op_type()
self.init_group()
self.init_dilation()
......@@ -70,30 +70,59 @@ class TestConv2dOp(OpTest):
'strides': self.stride,
'paddings': self.pad,
'groups': self.groups,
'dilations': self.dilations
'dilations': self.dilations,
'use_cudnn': self.use_cudnn
}
self.outputs = {'Output': output}
def test_check_output(self):
self.check_output()
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output()
def test_check_grad(self):
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
set(['Input', 'Filter']),
'Output',
max_relative_error=0.02)
else:
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
def test_check_grad_no_filter(self):
self.check_grad(
['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
else:
self.check_grad(
['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
else:
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
def init_test_case(self):
self.pad = [0, 0]
......@@ -167,39 +196,39 @@ class TestWithDilation(TestConv2dOp):
self.groups = 3
#----------------Conv2dCudnn----------------
class TestCudnn(TestConv2dOp):
#----------------Conv2dCUDNN----------------
class TestCUDNN(TestConv2dOp):
def init_op_type(self):
core.use_cudnn()
self.op_type = "conv2d_cudnn"
self.use_cudnn = True
self.op_type = "conv2d"
class TestCudnnWithPad(TestWithPad):
class TestCUDNNWithPad(TestWithPad):
def init_op_type(self):
core.use_cudnn()
self.op_type = "conv2d_cudnn"
self.use_cudnn = True
self.op_type = "conv2d"
class TestCudnnWithStride(TestWithStride):
class TestCUDNNWithStride(TestWithStride):
def init_op_type(self):
core.use_cudnn()
self.op_type = "conv2d_cudnn"
self.use_cudnn = True
self.op_type = "conv2d"
class TestCudnnWithGroup(TestWithGroup):
class TestCUDNNWithGroup(TestWithGroup):
def init_op_type(self):
core.use_cudnn()
self.op_type = "conv2d_cudnn"
self.use_cudnn = True
self.op_type = "conv2d"
class TestCudnnWith1x1(TestWith1x1):
class TestCUDNNWith1x1(TestWith1x1):
def init_op_type(self):
core.use_cudnn()
self.op_type = "conv2d_cudnn"
self.use_cudnn = True
self.op_type = "conv2d"
# cudnn v5 does not support dilation conv.
# class TestCudnnWithDilation(TestWithDilation):
# class TestCUDNNWithDilation(TestWithDilation):
# def init_op_type(self):
# self.op_type = "conv_cudnn"
......
import unittest
import numpy as np
import paddle.v2.fluid.core as core
from op_test import OpTest
......@@ -37,6 +39,7 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
class TestConv2dTransposeOp(OpTest):
def setUp(self):
# init as conv transpose
self.use_cudnn = False
self.init_op_type()
self.init_test_case()
......@@ -47,7 +50,9 @@ class TestConv2dTransposeOp(OpTest):
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'dilations': self.dilations
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'data_format': 'AnyLayout' # TODO(dzhwinter): should be fixed later
}
output = conv2dtranspose_forward_naive(input_, filter_,
......@@ -56,25 +61,53 @@ class TestConv2dTransposeOp(OpTest):
self.outputs = {'Output': output}
def test_check_output(self):
self.check_output()
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output()
def test_check_grad_no_input(self):
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
else:
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
def test_check_grad_no_filter(self):
self.check_grad(
['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
else:
self.check_grad(
['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
def test_check_grad(self):
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
set(['Input', 'Filter']),
'Output',
max_relative_error=0.02)
else:
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
def init_test_case(self):
self.pad = [0, 0]
......@@ -119,12 +152,13 @@ class TestWithDilation(TestConv2dTransposeOp):
# ------------ test_cudnn ------------
class TestCudnn(TestConv2dTransposeOp):
class TestCUDNN(TestConv2dTransposeOp):
def init_op_type(self):
self.op_type = "conv2d_transpose_cudnn"
self.use_cudnn = True
self.op_type = "conv2d_transpose"
class TestCudnnWithPad(TestWithPad):
class TestCUDNNWithPad(TestWithPad):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [1, 1]
......@@ -134,10 +168,11 @@ class TestCudnnWithPad(TestWithPad):
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.op_type = "conv2d_transpose_cudnn"
self.use_cudnn = True
self.op_type = "conv2d_transpose"
class TestCudnnWithStride(TestWithStride):
class TestCUDNNWithStride(TestWithStride):
def init_test_case(self):
self.pad = [1, 1]
self.stride = [2, 2]
......@@ -147,11 +182,12 @@ class TestCudnnWithStride(TestWithStride):
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.op_type = "conv2d_transpose_cudnn"
self.use_cudnn = True
self.op_type = "conv2d_transpose"
# #cudnn v5 does not support dilation conv.
# class TestCudnnWithDilation(TestWithDilation):
# class TestCUDNNWithDilation(TestWithDilation):
# def init_test_case(self):
# self.pad = [1, 1]
# self.stride = [2, 2]
......@@ -161,7 +197,7 @@ class TestCudnnWithStride(TestWithStride):
# self.filter_size = [f_c, 6, 3, 3]
#
# def init_op_type(self):
# self.op_type = "conv2d_transpose_cudnn"
# self.op_type = "conv2d_transpose"
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
import paddle.v2.fluid.core as core
from op_test import OpTest
......@@ -54,6 +56,7 @@ def conv3d_forward_naive(input, filter, group, conv_param):
class TestConv3dOp(OpTest):
def setUp(self):
self.use_cudnn = False
self.init_group()
self.init_op_type()
self.init_dilation()
......@@ -62,7 +65,9 @@ class TestConv3dOp(OpTest):
conv3d_param = {
'stride': self.stride,
'pad': self.pad,
'dilations': self.dilations
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'data_format': 'AnyLayout' # TODO(dzhwinter): should be fixed later
}
input = np.random.random(self.input_size).astype("float32")
filter = np.random.random(self.filter_size).astype("float32")
......@@ -79,25 +84,53 @@ class TestConv3dOp(OpTest):
self.outputs = {'Output': output}
def test_check_output(self):
self.check_output()
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output()
def test_check_grad(self):
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.03)
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
set(['Input', 'Filter']),
'Output',
max_relative_error=0.03)
else:
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.03)
def test_check_grad_no_filter(self):
self.check_grad(
['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']))
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']))
else:
self.check_grad(
['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']))
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']))
else:
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']))
def init_test_case(self):
self.pad = [0, 0, 0]
......@@ -169,31 +202,35 @@ class TestWithDilation(TestConv3dOp):
self.groups = 3
class TestCudnn(TestConv3dOp):
class TestCUDNN(TestConv3dOp):
def init_op_type(self):
self.op_type = "conv3d_cudnn"
self.use_cudnn = True
self.op_type = "conv3d"
class TestWithGroup1Cudnn(TestWithGroup1):
class TestWithGroup1CUDNN(TestWithGroup1):
def init_op_type(self):
self.op_type = "conv3d_cudnn"
self.use_cudnn = True
self.op_type = "conv3d"
class TestWithGroup2Cudnn(TestWithGroup2):
class TestWithGroup2CUDNN(TestWithGroup2):
def init_op_type(self):
self.op_type = "conv3d_cudnn"
self.use_cudnn = True
self.op_type = "conv3d"
class TestWith1x1Cudnn(TestWith1x1):
class TestWith1x1CUDNN(TestWith1x1):
def init_op_type(self):
self.op_type = "conv3d_cudnn"
self.use_cudnn = True
self.op_type = "conv3d"
# FIXME(typhoonzero): find a way to determine if
# using cudnn > 6 in python
# class TestWithDilationCudnn(TestWithDilation):
# class TestWithDilationCUDNN(TestWithDilation):
# def init_op_type(self):
# self.op_type = "conv3d_cudnn"
# self.op_type = "conv3d"
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
import paddle.v2.fluid.core as core
from op_test import OpTest
......@@ -44,6 +46,7 @@ def conv3dtranspose_forward_naive(input_, filter_, attrs):
class TestConv3dTransposeOp(OpTest):
def setUp(self):
# init as conv transpose
self.use_cudnn = False
self.init_op_type()
self.init_test_case()
......@@ -54,7 +57,9 @@ class TestConv3dTransposeOp(OpTest):
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'dilations': self.dilations
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'data_format': 'AnyLayout' # TODO(dzhwinter): should be fixed later
}
output = conv3dtranspose_forward_naive(input_, filter_,
......@@ -63,25 +68,53 @@ class TestConv3dTransposeOp(OpTest):
self.outputs = {'Output': output}
def test_check_output(self):
self.check_output()
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output()
def test_check_grad(self):
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
set(['Input', 'Filter']),
'Output',
max_relative_error=0.03)
else:
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.03)
def test_check_grad_no_filter(self):
self.check_grad(
['Input'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Filter']))
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']))
else:
self.check_grad(
['Input'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.02,
no_grad_set=set(['Input']))
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']))
else:
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.03,
no_grad_set=set(['Input']))
def init_test_case(self):
self.pad = [0, 0, 0]
......@@ -126,12 +159,13 @@ class TestWithDilation(TestConv3dTransposeOp):
# ------------ test_cudnn ------------
class TestCudnn(TestConv3dTransposeOp):
class TestCUDNN(TestConv3dTransposeOp):
def init_op_type(self):
self.op_type = "conv3d_transpose_cudnn"
self.use_cudnn = True
self.op_type = "conv3d_transpose"
class TestCudnnWithPad(TestWithPad):
class TestCUDNNWithPad(TestWithPad):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [1, 1, 1]
......@@ -141,10 +175,11 @@ class TestCudnnWithPad(TestWithPad):
self.filter_size = [f_c, 6, 3, 3, 3]
def init_op_type(self):
self.op_type = "conv3d_transpose_cudnn"
self.use_cudnn = True
self.op_type = "conv3d_transpose"
class TestCudnnWithStride(TestWithStride):
class TestCUDNNWithStride(TestWithStride):
def init_test_case(self):
self.pad = [1, 1, 1]
self.stride = [2, 2, 2]
......@@ -154,11 +189,12 @@ class TestCudnnWithStride(TestWithStride):
self.filter_size = [f_c, 6, 3, 3, 3]
def init_op_type(self):
self.op_type = "conv3d_transpose_cudnn"
self.use_cudnn = True
self.op_type = "conv3d_transpose"
# #cudnn v5 does not support dilation conv.
# class TestCudnnWithDilation(TestWithDilation):
# class TestCUDNNWithDilation(TestWithDilation):
# def init_test_case(self):
# self.pad = [1, 1, 1]
# self.stride = [2, 2, 2]
......@@ -168,7 +204,7 @@ class TestCudnnWithStride(TestWithStride):
# self.filter_size = [f_c, 6, 3, 3, 3]
#
# def init_op_type(self):
# self.op_type = "conv3d_transpose_cudnn"
# self.op_type = "conv3d_transpose"
if __name__ == '__main__':
unittest.main()
import unittest
import paddle.v2.fluid as fluid
import numpy
import sys
# TODO(dzhwinter): the get_places op check needs to be enhanced.
sys.exit(0)
class BaseParallelForTest(unittest.TestCase):
......@@ -13,13 +17,13 @@ class BaseParallelForTest(unittest.TestCase):
returns the data layers, and the second yield returns the loss.
The modified data variables will be sent back during the first
yield.
feed(dict): The executor feeding dictionary.
fetch(list|basestr): The fetch name lists.
Returns:
None
Raises:
AssertionError when the computations on cpu, parallel.for on cpu,
gpu, and parallel.for on gpu produce different results.
......
import unittest
import numpy as np
import paddle.v2.fluid.core as core
from op_test import OpTest
......@@ -44,6 +46,7 @@ def avg_pool2D_forward_naive(x, ksize, strides, paddings, global_pool=0):
class TestPool2d_Op(OpTest):
def setUp(self):
self.use_cudnn = False
self.init_test_case()
self.init_global_pool()
self.init_op_type()
......@@ -62,15 +65,25 @@ class TestPool2d_Op(OpTest):
'ksize': self.ksize,
'pooling_type': self.pool_type,
'global_pooling': self.global_pool,
'use_cudnn': self.use_cudnn,
'data_format': 'AnyLayout' # TODO(dzhwinter): should be fixed later
}
self.outputs = {'Out': output.astype('float32')}
def test_check_output(self):
self.check_output()
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output()
def test_check_grad(self):
if self.pool_type != "max":
if self.use_cudnn and self.pool_type != "max":
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, set(['X']), 'Out', max_relative_error=0.07)
elif self.pool_type != "max":
self.check_grad(set(['X']), 'Out', max_relative_error=0.07)
def init_test_case(self):
......@@ -153,35 +166,41 @@ class TestCase5(TestCase2):
self.pool2D_forward_naive = max_pool2D_forward_naive
#--------------------test pool2d_cudnn--------------------
class TestCudnnCase1(TestPool2d_Op):
#--------------------test pool2d--------------------
class TestCUDNNCase1(TestPool2d_Op):
def init_op_type(self):
self.op_type = "pool2d_cudnn"
self.use_cudnn = True
self.op_type = "pool2d"
class TestCudnnCase2(TestCase1):
class TestCUDNNCase2(TestCase1):
def init_op_type(self):
self.op_type = "pool2d_cudnn"
self.use_cudnn = True
self.op_type = "pool2d"
class TestCudnnCase3(TestCase2):
class TestCUDNNCase3(TestCase2):
def init_op_type(self):
self.op_type = "pool2d_cudnn"
self.use_cudnn = True
self.op_type = "pool2d"
class TestCudnnCase4(TestCase3):
class TestCUDNNCase4(TestCase3):
def init_op_type(self):
self.op_type = "pool2d_cudnn"
self.use_cudnn = True
self.op_type = "pool2d"
class TestCudnnCase5(TestCase4):
class TestCUDNNCase5(TestCase4):
def init_op_type(self):
self.op_type = "pool2d_cudnn"
self.use_cudnn = True
self.op_type = "pool2d"
class TestCudnnCase6(TestCase5):
class TestCUDNNCase6(TestCase5):
def init_op_type(self):
self.op_type = "pool2d_cudnn"
self.use_cudnn = True
self.op_type = "pool2d"
if __name__ == '__main__':
......
import unittest
import numpy as np
import paddle.v2.fluid.core as core
from op_test import OpTest
......@@ -52,6 +54,7 @@ def avg_pool3D_forward_naive(x, ksize, strides, paddings, global_pool=0):
class TestPool3d_Op(OpTest):
def setUp(self):
self.use_cudnn = False
self.init_test_case()
self.init_global_pool()
self.init_op_type()
......@@ -71,15 +74,25 @@ class TestPool3d_Op(OpTest):
'ksize': self.ksize,
'pooling_type': self.pool_type,
'global_pooling': self.global_pool,
'use_cudnn': self.use_cudnn,
'data_format': 'AnyLayout' # TODO(dzhwinter): should be fixed later
}
self.outputs = {'Out': output.astype('float32')}
def test_check_output(self):
self.check_output()
if self.use_cudnn:
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
else:
self.check_output()
def test_check_grad(self):
if self.pool_type != "max":
if self.use_cudnn and self.pool_type != "max":
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, set(['X']), 'Out', max_relative_error=0.07)
elif self.pool_type != "max":
self.check_grad(set(['X']), 'Out', max_relative_error=0.07)
def init_test_case(self):
......@@ -163,35 +176,41 @@ class TestCase5(TestCase2):
self.pool3D_forward_naive = max_pool3D_forward_naive
#--------------------test pool3d_cudnn--------------------
class TestCudnnCase1(TestPool3d_Op):
#--------------------test pool3d--------------------
class TestCUDNNCase1(TestPool3d_Op):
def init_op_type(self):
self.op_type = "pool3d_cudnn"
self.use_cudnn = True
self.op_type = "pool3d"
class TestCudnnCase2(TestCase1):
class TestCUDNNCase2(TestCase1):
def init_op_type(self):
self.op_type = "pool3d_cudnn"
self.use_cudnn = True
self.op_type = "pool3d"
class TestCudnnCase3(TestCase2):
class TestCUDNNCase3(TestCase2):
def init_op_type(self):
self.op_type = "pool3d_cudnn"
self.use_cudnn = True
self.op_type = "pool3d"
class TestCudnnCase4(TestCase3):
class TestCUDNNCase4(TestCase3):
def init_op_type(self):
self.op_type = "pool3d_cudnn"
self.use_cudnn = True
self.op_type = "pool3d"
class TestCudnnCase5(TestCase4):
class TestCUDNNCase5(TestCase4):
def init_op_type(self):
self.op_type = "pool3d_cudnn"
self.use_cudnn = True
self.op_type = "pool3d"
class TestCudnnCase6(TestCase5):
class TestCUDNNCase6(TestCase5):
def init_op_type(self):
self.op_type = "pool3d_cudnn"
self.use_cudnn = True
self.op_type = "pool3d"
if __name__ == '__main__':
......
import unittest
import numpy as np
import paddle.v2.fluid.core as core
from paddle.v2.fluid.executor import Executor
import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.backward import append_backward
from paddle.v2.fluid.framework import switch_main_program
from paddle.v2.fluid.framework import Program
class TestPrintOpCPU(unittest.TestCase):
def setUp(self):
self.place = core.CPUPlace()
self.x_tensor = core.LoDTensor()
tensor_np = np.random.random(size=(2, 3)).astype('float32')
self.x_tensor.set(tensor_np, self.place)
self.x_tensor.set_lod([[0, 1, 1]])
def build_network(self, only_forward, **kargs):
x = layers.data('x', shape=[3], dtype='float32', lod_level=1)
x.stop_gradient = False
printed = layers.Print(input=x, **kargs)
if only_forward: return printed
loss = layers.mean(x=printed)
append_backward(loss=loss)
return loss
def test_forward(self):
switch_main_program(Program())
printed = self.build_network(True, print_phase='forward')
exe = Executor(self.place)
outs = exe.run(feed={'x': self.x_tensor},
fetch_list=[printed],
return_numpy=False)
def test_backward(self):
switch_main_program(Program())
loss = self.build_network(False, print_phase='backward')
exe = Executor(self.place)
outs = exe.run(feed={'x': self.x_tensor},
fetch_list=[loss],
return_numpy=False)
class TestPrintOpGPU(TestPrintOpCPU):
def setUp(self):
self.place = core.CUDAPlace(0)
self.x_tensor = core.LoDTensor()
tensor_np = np.random.random(size=(2, 3)).astype('float32')
self.x_tensor.set(tensor_np, self.place)
self.x_tensor.set_lod([[0, 1, 1]])
if __name__ == '__main__':
......