提交 aca05d59 编写于 作者: S shippingwang

Merge branch 'release/1.0.0' of https://github.com/PaddlePaddle/Paddle into release/1.0.0

......@@ -61,12 +61,12 @@ paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None
paddle.fluid.layers.cross_entropy ArgSpec(args=['input', 'label', 'soft_label', 'ignore_index'], varargs=None, keywords=None, defaults=(False, -100))
paddle.fluid.layers.square_error_cost ArgSpec(args=['input', 'label'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.chunk_eval ArgSpec(args=['input', 'label', 'chunk_scheme', 'num_chunk_types', 'excluded_chunk_types'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None))
paddle.fluid.layers.sequence_conv ArgSpec(args=['input', 'num_filters', 'filter_size', 'filter_stride', 'padding', 'bias_attr', 'param_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(3, 1, None, None, None, None, None))
paddle.fluid.layers.conv2d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None))
paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, True, None, None))
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn'], varargs=None, keywords=None, defaults=(None, None, False))
paddle.fluid.layers.softmax ArgSpec(args=['input', 'param_attr', 'bias_attr', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(None, None, True, None))
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False))
......@@ -95,8 +95,8 @@ paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_ti
paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
......
......@@ -19,8 +19,18 @@ cc_library(paddle_fluid_origin DEPS ${fluid_modules} paddle_fluid_api)
add_subdirectory(api)
set(STATIC_INFERENCE_APIS paddle_fluid_api paddle_inference_api analysis_predictor)
set(SHARED_INFERENCE_SRCS
io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc
${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc )
if (WITH_GPU AND TENSORRT_FOUND)
set(STATIC_INFERENCE_APIS ${STATIC_INFERENCE_APIS} paddle_inference_tensorrt_subgraph_engine)
set(SHARED_INFERENCE_SRCS ${SHARED_INFERENCE_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/api/api_tensorrt_subgraph_engine.cc)
endif()
# Create static library
cc_library(paddle_fluid DEPS ${fluid_modules} paddle_fluid_api paddle_inference_api analysis_predictor)
cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} )
if(NOT APPLE)
# TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac.
set(LINK_FLAGS "-Wl,--retain-symbols-file ${CMAKE_CURRENT_SOURCE_DIR}/paddle_fluid.sym")
......@@ -28,9 +38,7 @@ if(NOT APPLE)
endif()
# Create shared library
cc_library(paddle_fluid_shared SHARED
SRCS io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc
${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc
cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS}
DEPS ${fluid_modules} paddle_fluid_api)
set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid)
......
......@@ -3,6 +3,7 @@ project(cpp_inference_demo CXX C)
option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON)
option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF)
option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON)
option(USE_TENSORRT "Compile demo with TensorRT." OFF)
macro(safe_set_static_flag)
foreach(flag_var
......@@ -60,6 +61,13 @@ endif(NOT WIN32)
include_directories("${PADDLE_LIB}/third_party/boost")
include_directories("${PADDLE_LIB}/third_party/eigen3")
if (NOT WIN32)
if (USE_TENSORRT AND WITH_GPU)
include_directories("${TENSORRT_INCLUDE_DIR}")
link_directories("${TENSORRT_LIB_DIR}")
endif()
endif(NOT WIN32)
if (NOT WIN32)
link_directories("${PADDLE_LIB}/third_party/install/snappy/lib")
link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib")
......@@ -112,6 +120,10 @@ endif(NOT WIN32)
if(WITH_GPU)
if(NOT WIN32)
if (USE_TENSORRT)
set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer${CMAKE_STATIC_LIBRARY_SUFFIX})
set(DEPS ${DEPS} ${TENSORRT_LIB_DIR}/libnvinfer_plugin${CMAKE_STATIC_LIBRARY_SUFFIX})
endif()
set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX})
else()
set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} )
......
......@@ -2,6 +2,12 @@ set -x
PADDLE_ROOT=$1
TURN_ON_MKL=$2 # use MKL or Openblas
TEST_GPU_CPU=$3 # test both GPU/CPU mode or only CPU mode
DATA_DIR=$4 # dataset
TENSORRT_INCLUDE_DIR=$5 # TensorRT header file dir, defalut to /usr/local/TensorRT/include
TENSORRT_LIB_DIR=$6 # TensorRT lib file dir, default to /usr/local/TensorRT/lib
cd `dirname $0`
current_dir=`pwd`
if [ $2 == ON ]; then
# You can export yourself if move the install path
MKL_LIB=${PADDLE_ROOT}/build/fluid_install_dir/third_party/install/mklml/lib
......@@ -13,6 +19,11 @@ else
use_gpu_list='false'
fi
USE_TENSORRT=OFF
if [ [-d"$TENSORRT_INCLUDE_DIR"] -a [-d"$TENSORRT_LIB_DIR"] ]; then
USE_TENSORRT=ON
fi
PREFIX=inference-vis-demos%2F
URL_ROOT=http://paddlemodels.cdn.bcebos.com/${PREFIX}
......@@ -29,15 +40,15 @@ function download() {
fi
cd ..
}
mkdir -p data
cd data
mkdir -p $DATA_DIR
cd $DATA_DIR
vis_demo_list='se_resnext50 ocr mobilenet'
for vis_demo_name in $vis_demo_list; do
download $vis_demo_name
done
cd ..
# compile and test the demo
cd $current_dir
mkdir -p build
cd build
......@@ -73,9 +84,9 @@ for WITH_STATIC_LIB in ON OFF; do
for use_gpu in $use_gpu_list; do
for vis_demo_name in $vis_demo_list; do
./vis_demo \
--modeldir=../data/$vis_demo_name/model \
--data=../data/$vis_demo_name/data.txt \
--refer=../data/$vis_demo_name/result.txt \
--modeldir=$DATA_DIR/$vis_demo_name/model \
--data=$DATA_DIR/$vis_demo_name/data.txt \
--refer=$DATA_DIR/$vis_demo_name/result.txt \
--use_gpu=$use_gpu
if [ $? -ne 0 ]; then
echo "vis demo $vis_demo_name runs fail."
......@@ -83,5 +94,25 @@ for WITH_STATIC_LIB in ON OFF; do
fi
done
done
# --------tensorrt mobilenet------
if [ $USE_TENSORRT == ON -a $TEST_GPU_CPU == ON ]; then
rm -rf *
cmake .. -DPADDLE_LIB=${PADDLE_ROOT}/build/fluid_install_dir/ \
-DWITH_MKL=$TURN_ON_MKL \
-DDEMO_NAME=vis_demo \
-DWITH_GPU=$TEST_GPU_CPU \
-DWITH_STATIC_LIB=$WITH_STATIC_LIB \
-DUSE_TENSORRT=$USE_TENSORRT \
-DTENSORRT_INCLUDE_DIR=$TENSORRT_INCLUDE_DIR \
-DTENSORRT_LIB_DIR=$TENSORRT_LIB_DIR
make -j
./vis_demo \
--modeldir=$DATA_DIR/mobilenet/model \
--data=$DATA_DIR/mobilenet/data.txt \
--refer=$DATA_DIR/mobilenet/result.txt \
--use_gpu=true \
--use_trt=true
fi
done
set +x
......@@ -33,6 +33,7 @@ DEFINE_string(
"path of data; each line is a record, format is "
"'<space splitted floats as data>\t<space splitted ints as shape'");
DEFINE_bool(use_gpu, false, "Whether use gpu.");
DEFINE_bool(use_trt, false, "Whether use trt.");
namespace paddle {
namespace demo {
......@@ -100,7 +101,9 @@ void CheckOutput(const std::string& referfile, const PaddleTensor& output) {
/*
* Use the native fluid engine to inference the demo.
*/
void Main(bool use_gpu) {
void Main(bool use_gpu, bool use_trt) {
std::unique_ptr<PaddlePredictor> predictor;
if (!use_trt) {
NativeConfig config;
config.param_file = FLAGS_modeldir + "/__params__";
config.prog_file = FLAGS_modeldir + "/__model__";
......@@ -111,8 +114,18 @@ void Main(bool use_gpu) {
}
VLOG(3) << "init predictor";
auto predictor =
predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
} else {
paddle::contrib::MixedRTConfig config;
config.param_file = FLAGS_modeldir + "/__params__";
config.prog_file = FLAGS_modeldir + "/__model__";
config.use_gpu = true;
config.device = 0;
config.max_batch_size = 1;
config.fraction_of_gpu_memory = 0.1; // set by yourself
predictor = CreatePaddlePredictor<paddle::contrib::MixedRTConfig>(config);
}
VLOG(3) << "begin to process data";
// Just a single batch of data.
......@@ -131,7 +144,7 @@ void Main(bool use_gpu) {
VLOG(3) << "run executor";
std::vector<PaddleTensor> output;
predictor->Run({input}, &output);
predictor->Run({input}, &output, 1);
VLOG(3) << "output.size " << output.size();
auto& tensor = output.front();
......@@ -146,9 +159,12 @@ void Main(bool use_gpu) {
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
paddle::demo::Main(false /* use_gpu*/);
if (FLAGS_use_gpu) {
paddle::demo::Main(true /*use_gpu*/);
if (FLAGS_use_gpu && FLAGS_use_trt) {
paddle::demo::Main(true /*use_gpu*/, true);
} else if (FLAGS_use_gpu) {
paddle::demo::Main(true /*use_gpu*/, false);
} else {
paddle::demo::Main(false /*use_gpu*/, false /*use_tensorrt*/);
}
return 0;
}
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class AdadeltaOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -31,6 +32,16 @@ class AdadeltaOp : public framework::OperatorWithKernel {
"Input(AvgSquaredGrad) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"),
"Input(AvgSquaredUpdate) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdadeltaOp should not be null.");
......@@ -56,6 +67,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
......
......@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class AdadeltaOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto avg_squared_grad_out_tensor =
ctx.Output<framework::Tensor>("AvgSquaredGradOut");
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -21,25 +22,31 @@ namespace operators {
template <typename DeviceContext, typename T>
struct SparseAdagradFunctor {
void operator()(const DeviceContext& context,
const framework::SelectedRows& grad,
const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param);
void operator()(const DeviceContext &context,
const framework::SelectedRows &grad,
const framework::Tensor &learning_rate, T epsilon,
framework::Tensor *moment, framework::Tensor *param);
};
template <typename DeviceContext, typename T>
class AdagradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto* moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto *param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto *moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
param_out_tensor->mutable_data<T>(ctx.GetPlace());
moment_out_tensor->mutable_data<T>(ctx.GetPlace());
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto* grad_var = ctx.InputVar("Grad");
auto *grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::LoDTensor>()) {
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
......@@ -47,16 +54,16 @@ class AdagradOpKernel : public framework::OpKernel<T> {
*ctx.Input<framework::Tensor>("Grad"));
auto moment = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment"));
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
auto *place = ctx.template device_context<DeviceContext>().eigen_device();
moment_out.device(*place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
if (platform::is_cpu_place(ctx.GetPlace())) {
auto* lr = learning_rate->data<T>();
auto *lr = learning_rate->data<T>();
param_out.device(*place) =
param - lr[0] * grad / (moment_out.sqrt() + epsilon);
} else {
......@@ -66,10 +73,10 @@ class AdagradOpKernel : public framework::OpKernel<T> {
lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto* param_tensor = ctx.Input<framework::Tensor>("Param");
auto *param_tensor = ctx.Input<framework::Tensor>("Param");
PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor);
auto* moment_tensor = ctx.Input<framework::Tensor>("Moment");
auto *moment_tensor = ctx.Input<framework::Tensor>("Moment");
PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor);
SparseAdagradFunctor<DeviceContext, T> functor;
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
......@@ -199,23 +200,9 @@ struct SparseAdamFunctor {
row_numel_(row_numel),
row_count_(row_count) {}
inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
int64_t beg = 0, end = row_count_ - 1;
while (beg <= end) {
auto mid = ((beg + end) >> 1);
if (rows_[mid] == row)
return mid;
else if (rows_[mid] < row)
beg = mid + 1;
else
end = mid - 1;
}
return -1;
}
inline HOSTDEVICE void operator()(size_t i) const {
int64_t row = i / row_numel_;
auto row_idx = BinarySearchInRows(row);
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
// The following code is the same as dense
......@@ -244,6 +231,12 @@ template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
using paddle::framework::LoDTensor;
using paddle::operators::detail::Ref;
......
......@@ -35,6 +35,16 @@ class AdamaxOp : public framework::OperatorWithKernel {
"Input(LearningRate) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
"Input(Beta1Pow) of AdamaxOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdamaxOp should not be null.");
......
......@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class AdamaxOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
auto inf_norm_out_tensor = ctx.Output<framework::Tensor>("InfNormOut");
......
......@@ -32,6 +32,16 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(
ctx->HasInput("LearningRate"),
"Input(LearningRate) of DecayedAdagradOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of DecayedAdagradOp should not be null.");
......
......@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class DecayedAdagradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
......
......@@ -34,6 +34,16 @@ class FTRLOp : public framework::OperatorWithKernel {
"Input(Grad) of FTRL should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of FTRL should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of FTRL should not be null.");
......
......@@ -28,6 +28,17 @@ template <typename DeviceContext, typename T>
class FTRLOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* sq_accum_out = ctx.Output<Tensor>("SquaredAccumOut");
auto* lin_accum_out = ctx.Output<Tensor>("LinearAccumOut");
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstdint> // for int64_t
#include <numeric>
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
int64_t beg = 0, end = num - 1;
while (beg <= end) {
auto mid = ((beg + end) >> 1);
if (x[mid] == val)
return mid;
else if (x[mid] < val)
beg = mid + 1;
else
end = mid - 1;
}
return -1;
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -24,7 +24,7 @@ class MomentumOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(param) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
......@@ -33,6 +33,11 @@ class MomentumOp : public framework::OperatorWithKernel {
"Input(velocity) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of Momentum should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of Momentum should not be null.");
......@@ -40,12 +45,15 @@ class MomentumOp : public framework::OperatorWithKernel {
"Output(VelocityOut) of Momentum should not be null.");
auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and Grad input of MomentumOp should have the same dimension.");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Velocity"),
"Param and Velocity of MomentumOp should have the same dimension.");
}
PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1,
"Learning_rate should be a scalar");
......@@ -53,13 +61,34 @@ class MomentumOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("VelocityOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
const framework::ExecutionContext& ctx) const override {
auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class MomentumOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto input_var = op_desc.Input("Param")[0];
for (auto& out_var : op_desc.Output("ParamOut")) {
if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::SELECTED_ROWS);
} else if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
framework::proto::VarType::LOD_TENSOR) {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::LOD_TENSOR);
} else {
PADDLE_THROW(
"Only support LodTensor and SelectedRows, Unexpected Input Type.");
}
}
}
};
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -110,6 +139,9 @@ $$
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(momentum, ops::MomentumOp, ops::MomentumOpMaker);
REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel<float>,
ops::MomentumOpKernel<double>);
REGISTER_OPERATOR(momentum, ops::MomentumOp, ops::MomentumOpMaker,
paddle::framework::EmptyGradOpMaker,
ops::MomentumOpInferVarType);
REGISTER_OP_CPU_KERNEL(
momentum, ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -15,65 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/momentum_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void MomentumKernel(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu,
const int64_t num, bool use_nesterov, T* p_out,
T* v_out) {
T lr = learning_rate[0];
if (use_nesterov) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
T g_val = g[i];
T v_new = v[i] * mu + g_val;
v_out[i] = v_new;
p_out[i] = p[i] - (g_val + v_new * mu) * lr;
}
} else {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
T v_new = v[i] * mu + g[i];
v_out[i] = v_new;
p_out[i] = p[i] - lr * v_new;
}
}
}
template <typename T>
class MomentumOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
auto param = ctx.Input<framework::Tensor>("Param");
auto velocity = ctx.Input<framework::Tensor>("Velocity");
auto grad = ctx.Input<framework::Tensor>("Grad");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto* p = param->data<T>();
auto* v = velocity->data<T>();
auto* g = grad->data<T>();
auto* lr = learning_rate->data<T>();
int block = 512;
int grid = (param->numel() + block - 1) / block;
MomentumKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(momentum, ops::MomentumOpCUDAKernel<float>,
ops::MomentumOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
momentum, ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -13,29 +13,48 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
auto param = ctx.Input<framework::Tensor>("Param");
auto velocity = ctx.Input<framework::Tensor>("Velocity");
auto grad = ctx.Input<framework::Tensor>("Grad");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
using framework::Tensor;
using framework::SelectedRows;
struct NoNesterov;
struct UseNesterov;
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
template <typename T>
class CPUDenseMomentumFunctor {
private:
const Tensor* param;
const Tensor* grad;
const Tensor* velocity;
const Tensor* learning_rate;
const T mu;
const T use_nesterov;
Tensor* param_out;
Tensor* velocity_out;
T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
public:
CPUDenseMomentumFunctor(const Tensor* param, const Tensor* grad,
const Tensor* velocity, const Tensor* learning_rate,
const T mu, const bool use_nesterov,
Tensor* param_out, Tensor* velocity_out)
: param(param),
grad(grad),
velocity(velocity),
learning_rate(learning_rate),
mu(mu),
use_nesterov(use_nesterov),
param_out(param_out),
velocity_out(velocity_out) {}
inline void operator()() {
auto p_out = framework::EigenVector<T>::Flatten(*param_out);
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
......@@ -53,5 +72,283 @@ class MomentumOpKernel : public framework::OpKernel<T> {
}
};
template <typename T, typename UpdateMethod>
class DenseMomentumFunctor;
// NOTE(dzh) for performance.
// avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two
// functor.
template <typename T>
class DenseMomentumFunctor<T, UseNesterov> {
private:
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const T mu_;
const int64_t num_;
T* p_out_;
T* v_out_;
public:
DenseMomentumFunctor(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu, const int64_t num,
T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
lr_(learning_rate),
mu_(mu),
num_(num),
p_out_(p_out),
v_out_(v_out) {}
inline HOSTDEVICE void operator()(size_t i) const {
// put memory access in register
const T p = p_[i];
const T g = g_[i];
const T lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - (g + v_out * mu_) * lr;
// write reigster to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
}
};
template <typename T>
class DenseMomentumFunctor<T, NoNesterov> {
private:
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const T mu_;
const int64_t num_;
T* p_out_;
T* v_out_;
public:
DenseMomentumFunctor(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu, const int64_t num,
T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
lr_(learning_rate),
mu_(mu),
num_(num),
p_out_(p_out),
v_out_(v_out) {}
inline HOSTDEVICE void operator()(size_t i) const {
// put memory access in register
const T p = p_[i];
const T g = g_[i];
const T lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - lr * v_out;
// write reigster to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
}
};
template <typename T, typename UpdateMethod>
class SparseMomentumFunctor;
template <typename T>
class SparseMomentumFunctor<T, UseNesterov> {
private:
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const T mu_;
const int64_t* rows_;
const int64_t row_numel_;
const int64_t row_height_;
T* p_out_;
T* v_out_;
public:
SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
const T mu, const int64_t* rows, int64_t row_numel,
int64_t row_height, T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
lr_(lr),
mu_(mu),
rows_(rows),
row_numel_(row_numel),
row_height_(row_height),
p_out_(p_out),
v_out_(v_out) {}
inline HOSTDEVICE void operator()(size_t i) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
// put memory access in register
const T p = p_[i];
const T lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - (g + v_out * mu_) * lr;
// write reigster to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
}
};
template <typename T>
class SparseMomentumFunctor<T, NoNesterov> {
private:
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const T mu_;
const int64_t* rows_;
const int64_t row_numel_;
const int64_t row_height_;
T* p_out_;
T* v_out_;
public:
SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
const T mu, const int64_t* rows, int64_t row_numel,
int64_t row_height, T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
lr_(lr),
mu_(mu),
rows_(rows),
row_numel_(row_numel),
row_height_(row_height),
p_out_(p_out),
v_out_(v_out) {}
inline HOSTDEVICE void operator()(size_t i) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
// put memory access in register
const T p = p_[i];
const T lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - v_out * lr;
// write reigster to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
}
};
template <typename DeviceContext, typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param = ctx.Input<framework::Tensor>("Param");
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* velocity = ctx.Input<framework::Tensor>("Velocity");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
auto* grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::LoDTensor>()) {
auto grad = ctx.Input<framework::Tensor>("Grad");
if (platform::is_cpu_place(ctx.GetPlace())) {
CPUDenseMomentumFunctor<T> functor(param, grad, velocity, learning_rate,
mu, use_nesterov, param_out,
velocity_out);
functor();
} else if (platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param->numel());
if (use_nesterov) {
DenseMomentumFunctor<T, UseNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<T>(),
learning_rate->data<T>(), mu, param->numel(),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
} else {
DenseMomentumFunctor<T, NoNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<T>(),
learning_rate->data<T>(), mu, param->numel(),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
}
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
// sparse update embedding with selectedrows
auto grad = ctx.Input<framework::SelectedRows>("Grad");
// sparse update maybe empty.
if (grad->rows().size() == 0) {
VLOG(3) << "Grad SelectedRows contains no data!";
return;
}
auto* merged_grad = const_cast<framework::Scope&>(ctx.scope())
.Var()
->GetMutable<framework::SelectedRows>();
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(ctx.template device_context<DeviceContext>(), *grad,
merged_grad);
const int64_t* rows = nullptr;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
rows = merged_grad->rows().CUDAData(ctx.GetPlace());
} else {
#endif
rows = merged_grad->rows().data();
#ifdef PADDLE_WITH_CUDA
}
#endif
int64_t row_numel =
merged_grad->value().numel() / merged_grad->rows().size();
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param->numel());
if (use_nesterov) {
SparseMomentumFunctor<T, UseNesterov> functor(
param->data<T>(), merged_grad->value().data<T>(),
velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
static_cast<int64_t>(merged_grad->rows().size()),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
} else {
SparseMomentumFunctor<T, NoNesterov> functor(
param->data<T>(), merged_grad->value().data<T>(),
velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
static_cast<int64_t>(merged_grad->rows().size()),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
}
} else {
PADDLE_THROW(
string::Sprintf("MomentumOp only supports LoDTensor or SelectedRows "
"gradient, but the received Variable Type is %s",
grad_var->Type().name()));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -32,6 +32,11 @@ class RmspropOp : public framework::OperatorWithKernel {
"Input(Grad) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of RmspropOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of RmspropOp should not be null.");
......
......@@ -13,66 +13,259 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
struct DenseRmspropGradFunctor {
inline explicit DenseRmspropGradFunctor(const T *grad) : grad_(grad) {}
HOSTDEVICE inline T operator()(int64_t idx) const { return grad_[idx]; }
const T *grad_;
};
template <typename T>
struct SparseRmspropGradFunctor {
inline SparseRmspropGradFunctor(const T *grad, const int64_t *rows,
int64_t row_numel, int64_t row_count)
: grad_(grad),
rows_(rows),
row_numel_(row_numel),
row_count_(row_count) {}
HOSTDEVICE inline T operator()(int64_t idx) const {
auto row_idx = math::BinarySearch(rows_, row_count_, idx / row_numel_);
return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_] : 0;
}
const T *grad_;
const int64_t *rows_;
int64_t row_numel_;
int64_t row_count_;
};
template <typename T, typename GradFunctor>
struct UncenteredRmspropFunctor {
UncenteredRmspropFunctor(T *param, T *ms, T *mom, const T *lr, T rho,
T epsilon, T momentum,
const GradFunctor &grad_functor)
: param_(param),
ms_(ms),
mom_(mom),
lr_(lr),
rho_(rho),
epsilon_(epsilon),
momentum_(momentum),
grad_functor_(grad_functor) {}
HOSTDEVICE inline void operator()(int64_t idx) const {
T g = grad_functor_(idx);
T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
T mom_out = momentum_ * mom_[idx] + lr_[0] * g / sqrt(ms_out + epsilon_);
param_[idx] -= mom_out;
ms_[idx] = ms_out;
mom_[idx] = mom_out;
}
T *param_;
T *ms_;
T *mom_;
const T *lr_;
T rho_;
T epsilon_;
T momentum_;
GradFunctor grad_functor_;
};
template <typename T, typename GradFunctor>
struct CenteredRmspropFunctor {
CenteredRmspropFunctor(T *param, T *ms, T *mom, T *mean_grad, const T *lr,
T rho, T epsilon, T momentum,
const GradFunctor &grad_functor)
: param_(param),
ms_(ms),
mom_(mom),
mean_grad_(mean_grad),
lr_(lr),
rho_(rho),
epsilon_(epsilon),
momentum_(momentum),
grad_functor_(grad_functor) {}
HOSTDEVICE inline void operator()(int64_t idx) const {
T g = grad_functor_(idx);
T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
T mg_out = rho_ * mean_grad_[idx] + (1 - rho_) * g;
T mom_out = momentum_ * mom_[idx] +
lr_[0] * g / sqrt(ms_out - mg_out * mg_out + epsilon_);
param_[idx] -= mom_out;
ms_[idx] = ms_out;
mom_[idx] = mom_out;
mean_grad_[idx] = mg_out;
}
T *param_;
T *ms_;
T *mom_;
T *mean_grad_;
const T *lr_;
T rho_;
T epsilon_;
T momentum_;
GradFunctor grad_functor_;
};
template <typename DeviceContext, typename T>
class RmspropOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* moment_out = ctx.Output<Tensor>("MomentOut");
auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
void Compute(const framework::ExecutionContext &ctx) const override {
using LoDTensor = framework::LoDTensor;
const auto *param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto *grad_var = ctx.InputVar("Grad");
auto *param_out = ctx.Output<LoDTensor>("ParamOut");
auto *moment_out = ctx.Output<LoDTensor>("MomentOut");
auto *mean_square_out = ctx.Output<LoDTensor>("MeanSquareOut");
auto grad = ctx.Input<Tensor>("Grad");
auto epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto rho = static_cast<T>(ctx.Attr<float>("decay"));
auto momentum = static_cast<T>(ctx.Attr<float>("momentum"));
bool centered = ctx.Attr<bool>("centered");
param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace());
mean_square_out->mutable_data<T>(ctx.GetPlace());
auto &p_tensor = *ctx.Input<LoDTensor>("Param");
auto &ms_tensor = *ctx.Input<LoDTensor>("MeanSquare");
auto &lr_tensor = *ctx.Input<LoDTensor>("LearningRate");
auto &mom_tensor = *ctx.Input<LoDTensor>("Moment");
float epsilon = ctx.Attr<float>("epsilon");
float rho = ctx.Attr<float>("decay");
float momentum = ctx.Attr<float>("momentum");
bool centered = ctx.Attr<bool>("centered");
PADDLE_ENFORCE_EQ(&p_tensor, param_out,
"Param and ParamOut must be the same Tensor");
PADDLE_ENFORCE_EQ(&mom_tensor, moment_out,
"Moment and MomentOut must be the same Tensor");
PADDLE_ENFORCE_EQ(&ms_tensor, mean_square_out,
"MeanSquare and MeanSquareOut must be the same Tensor");
auto &dev_ctx = ctx.template device_context<DeviceContext>();
size_t limit = static_cast<size_t>(ms_tensor.numel());
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
auto g = EigenVector<T>::Flatten(*grad);
auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
if (grad_var->IsType<LoDTensor>()) {
auto &grad_tensor = grad_var->Get<LoDTensor>();
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value) {
auto &place =
*ctx.template device_context<DeviceContext>().eigen_device();
auto lr_value = lr_tensor.data<T>()[0];
auto p = EigenVector<T>::Flatten(p_tensor);
auto ms = EigenVector<T>::Flatten(ms_tensor);
auto g = EigenVector<T>::Flatten(grad_tensor);
auto mom = EigenVector<T>::Flatten(mom_tensor);
auto p_out = EigenVector<T>::Flatten(*param_out);
auto mom_out = EigenVector<T>::Flatten(*moment_out);
auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 1> grad_dsize(static_cast<int>(grad->numel()));
ms_out.device(place) = rho * ms + (1 - rho) * g * g;
if (centered) {
auto mg = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanGrad"));
auto* mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
mean_grad_out->mutable_data<T>(ctx.GetPlace());
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto mg = EigenVector<T>::Flatten(mg_tensor);
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
mg_out.device(place) = rho * mg + (1 - rho) * g;
mom_out.device(place) = momentum * mom +
lr.broadcast(grad_dsize) * g /
(ms_out - mg_out.square() + epsilon).sqrt();
} else {
mom_out.device(place) =
momentum * mom +
lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
lr_value * g / (ms_out - mg_out.square() + epsilon).sqrt();
} else {
mom_out.device(place) =
momentum * mom + lr_value * g / (ms_out + epsilon).sqrt();
}
p_out.device(place) = p - mom_out;
} else {
DenseRmspropGradFunctor<T> grad_func(grad_tensor.data<T>());
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
moment_out->mutable_data<T>(ctx.GetPlace()),
mean_grad_out->mutable_data<T>(ctx.GetPlace()),
lr_tensor.data<T>(), rho, epsilon, momentum, grad_func));
} else {
for_range(UncenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
moment_out->mutable_data<T>(ctx.GetPlace()), lr_tensor.data<T>(),
rho, epsilon, momentum, grad_func));
}
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto &grad = grad_var->Get<framework::SelectedRows>();
auto *merged_grad = const_cast<framework::Scope &>(ctx.scope())
.Var()
->GetMutable<framework::SelectedRows>();
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(dev_ctx, grad, merged_grad);
platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
const int64_t *rows;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
rows = merged_grad->rows().CUDAData(ctx.GetPlace());
} else {
#endif
rows = merged_grad->rows().data();
#ifdef PADDLE_WITH_CUDA
}
#endif
auto &merged_tensor = merged_grad->value();
int64_t row_count = merged_grad->rows().size();
int64_t row_numel = merged_tensor.numel() / row_count;
SparseRmspropGradFunctor<T> grad_func(merged_tensor.data<T>(), rows,
row_numel, row_count);
if (centered) {
auto &mg_tensor = *ctx.Input<LoDTensor>("MeanGrad");
auto *mean_grad_out = ctx.Output<LoDTensor>("MeanGradOut");
PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
"MeanGrad and MeanGradOut must be the same Tensor");
for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
moment_out->mutable_data<T>(ctx.GetPlace()),
mean_grad_out->mutable_data<T>(ctx.GetPlace()), lr_tensor.data<T>(),
rho, epsilon, momentum, grad_func));
} else {
for_range(UncenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
param_out->mutable_data<T>(ctx.GetPlace()),
mean_square_out->mutable_data<T>(ctx.GetPlace()),
moment_out->mutable_data<T>(ctx.GetPlace()), lr_tensor.data<T>(),
rho, epsilon, momentum, grad_func));
}
} else {
PADDLE_THROW("RMSProp only supports LoDTensor or SelectedRows gradient");
}
}
};
......
......@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
......@@ -42,7 +42,7 @@ class SGDOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext &ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
return framework::OpKernelType(data_type, ctx.device_context());
}
......@@ -50,17 +50,20 @@ class SGDOp : public framework::OperatorWithKernel {
class SGDOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto input_var = op_desc.Input("Param")[0];
for (auto& out_var : op_desc.Output("ParamOut")) {
if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::SELECTED_ROWS);
} else {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::LOD_TENSOR);
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto input_var_n = op_desc.Input("Param")[0];
auto in_var_type = block->FindRecursiveOrCreateVar(input_var_n).GetType();
PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
"The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s",
input_var_n, in_var_type);
for (auto &out_var_n : op_desc.Output("ParamOut")) {
auto &out_var = block->FindRecursiveOrCreateVar(out_var_n);
if (out_var.GetType() != in_var_type) {
out_var.SetType(in_var_type);
}
}
}
......
......@@ -57,6 +57,12 @@ template <typename T>
class SGDOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto* param = ctx.Input<framework::Tensor>("Param");
auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
......
......@@ -23,14 +23,14 @@ namespace operators {
template <typename T>
class CPUUniformRandomKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
framework::Tensor* tensor = nullptr;
void Compute(const framework::ExecutionContext &ctx) const override {
framework::Tensor *tensor = nullptr;
auto out_var = ctx.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();
} else if (out_var->IsType<framework::SelectedRows>()) {
auto shape = ctx.Attr<std::vector<int>>("shape");
auto* selected_rows = out_var->GetMutable<framework::SelectedRows>();
auto *selected_rows = out_var->GetMutable<framework::SelectedRows>();
tensor = selected_rows->mutable_value();
tensor->Resize(framework::make_ddim(shape));
selected_rows->mutable_rows()->reserve(shape[0]);
......@@ -39,7 +39,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
"uniform_random_op's output only"
"supports SelectedRows and LoDTensor");
}
T* data = tensor->mutable_data<T>(ctx.GetPlace());
T *data = tensor->mutable_data<T>(ctx.GetPlace());
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
......@@ -60,14 +60,14 @@ class UniformRandomOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UniformRandomOp should not be null.");
PADDLE_ENFORCE(
ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
"uniform_random's min must less then max");
auto& shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
std::vector<int64_t> temp;
temp.reserve(shape.size());
for (auto dim : shape) {
......@@ -78,7 +78,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
......@@ -112,17 +112,17 @@ uniform distribution. The random result is in set [min, max].
class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Out").front();
if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(framework::proto::VarType::SELECTED_ROWS);
} else {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(framework::proto::VarType::LOD_TENSOR);
auto var_data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(op_desc.GetAttr("dtype")));
auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) {
out_var.SetType(framework::proto::VarType::LOD_TENSOR);
}
out_var.SetDataType(var_data_type);
}
};
......
......@@ -156,7 +156,50 @@ PYBIND11_PLUGIN(core) {
.def("_get_double_element", TensorGetElement<double>)
.def("_dtype", [](Tensor &self) { return ToDataType(self.type()); });
py::class_<LoDTensor, Tensor>(m, "LoDTensor")
py::class_<LoDTensor, Tensor>(m, "LoDTensor", R"DOC(
LoDTensor is a Tensor with optional LoD information.
np.array(lod_tensor) can convert LoDTensor to numpy array.
lod_tensor.lod() can retrieve the LoD information.
LoD is short for Level of Details and is usually used for varied sequence
length. You can skip the following comment if you don't need optional LoD.
For example:
A LoDTensor X can look like the example below. It contains 2 sequences.
The first has length 2 and the second has length 3, as described by x.lod.
The first tensor dimension 5=2+3 is calculated from LoD if it's available.
It means the total number of sequence element. In X, each element has 2
columns, hence [5, 2].
x.lod = [[2, 3]]
x.data = [[1, 2], [3, 4], // seq 1
[5, 6], [7, 8], [9, 10]] // seq 2
x.shape = [5, 2]
LoD can have multiple levels (for example, a paragraph can have multiple
sentences and a sentence can have multiple words). In the following
LodTensor Y, the lod_level is 2. It means there are 2 sequence, the
first sequence length is 2 (has 2 sub-sequences), the second one's
length is 1. The first sequence's 2 sub-sequences have length 2 and 2,
respectively. And the second sequence's 1 sub-sequence has length 3.
y.lod = [[2 1], [2 2 3]]
y.shape = [2+2+3, ...]
Note:
In above description, LoD is length-based. In Paddle internal
implementation, lod is offset-based. Hence, internally,
y.lod is represented as [[0, 2, 3], [0, 2, 4, 7]] (length-based
equivlent would be [[2-0, 3-2], [2-0, 4-2, 7-4]]).
Sometimes LoD is called recursive_sequence_length to be more
self-explanatory. In this case, it must be length-based. Due to history
reasons. when LoD is called lod in public API, it might be offset-based.
Users should be careful about it.
)DOC")
.def_buffer(
[](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
.def("__init__",
......@@ -596,26 +639,58 @@ All parameter, weight, gradient are variables in Paddle.
// -- python binds for parallel executor.
py::class_<ParallelExecutor> pe(m, "ParallelExecutor");
py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy");
py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy", R"DOC(
ExecutionStrategy allows the user to more preciously control how to run
the program in ParallelExecutor by setting the property.
Examples:
.. code-block:: python
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 4
train_exe = fluid.ParallelExecutor(use_cuda=True,
loss_name=loss.name,
exec_strategy=exec_strategy)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
)DOC");
exec_strategy.def(py::init())
.def_property(
"num_threads",
[](const ExecutionStrategy &self) { return self.num_threads_; },
[](ExecutionStrategy &self, size_t num_threads) {
self.num_threads_ = num_threads;
})
},
R"DOC(The type is INT, num_threads represents the size of thread pool that
used to run the operators of the current program in ParallelExecutor.
If :math:`num\_threads=1`, all the operators will execute one by one,
but the order maybe difference between iterations.
If it is not set, it will be set in ParallelExecutor according to the
device type and device count, for GPU, :math:`num\_threads=device\_count*4`, for CPU,
:math:`num\_threads=CPU\_NUM*4`, the explanation of:math:`CPU\_NUM` is in ParallelExecutor.
if it is not set, ParallelExecutor will get the cpu count by calling
`multiprocessing.cpu_count()`. Default 0.)DOC")
.def_property(
"use_cuda",
[](const ExecutionStrategy &self) { return self.use_cuda_; },
[](ExecutionStrategy &self, bool use_cuda) {
self.use_cuda_ = use_cuda;
})
}) // FIXME(chengduo): Doesn't add doc for 'use_cuda', use_cuda may
// make user confuse, because ParallelExecutor has a parameter named
// 'use_cuda' too, in current implementation, ParallelExecutor's
// 'use_cuda' will rewrite ExecutionStrategy's 'use_cuda'.
.def_property(
"allow_op_delay",
[](const ExecutionStrategy &self) { return self.allow_op_delay_; },
[](ExecutionStrategy &self, bool allow_op_delay) {
self.allow_op_delay_ = allow_op_delay;
})
},
R"DOC(The type is BOOL, allow_op_delay represents whether to delay the
communication operators to run, it may make the execution faster.
Note that in some models, allow_op_delay may cause program hang. Default False.)DOC")
.def_property(
"num_iteration_per_drop_scope",
[](const ExecutionStrategy &self) {
......@@ -623,7 +698,19 @@ All parameter, weight, gradient are variables in Paddle.
},
[](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) {
self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope;
});
},
R"DOC(The type is INT, num_iteration_per_drop_scope indicates how
many iterations to clean up the temp variables which
is generated during execution. It may make the execution faster,
because the temp variable's shape maybe the same between two iterations. Default 100.
NOTES:
1. If you fetch data when calling the 'run', the ParallelExecutor
will clean up the temp variables at the end of the current iteration.
2. In some NLP model, it may cause the GPU memory is insufficient,
in this case, you should reduce `num_iteration_per_drop_scope`.
)DOC");
exec_strategy.def_property(
"use_experimental_executor",
[](const ExecutionStrategy &self) {
......@@ -634,7 +721,22 @@ All parameter, weight, gradient are variables in Paddle.
: ExecutionStrategy::kDefault;
});
py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy");
py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy", R"DOC(
BuildStrategy allows the user to more preciously control how to
build the SSA Graph in ParallelExecutor by setting the property.
Examples:
.. code-block:: python
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
train_exe = fluid.ParallelExecutor(use_cuda=True,
loss_name=loss.name,
build_strategy=build_strategy)
train_loss, = train_exe.run([loss.name], feed=feed_dict)
)DOC");
py::enum_<BuildStrategy::ReduceStrategy>(build_strategy, "ReduceStrategy")
.value("Reduce", BuildStrategy::ReduceStrategy::kReduce)
......@@ -652,31 +754,51 @@ All parameter, weight, gradient are variables in Paddle.
[](const BuildStrategy &self) { return self.reduce_; },
[](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) {
self.reduce_ = strategy;
})
},
R"DOC(The type is STR, there are two reduce strategies in ParallelExecutor,
'AllReduce' and 'Reduce'. If you want that all the parameters'
optimization are done on all devices independently, you should choose 'AllReduce';
if you choose 'Reduce', all the parameters' optimization will be evenly distributed
to different devices, and then broadcast the optimized parameter to other devices.
In some models, `Reduce` is faster. Default 'AllReduce'. )DOC")
.def_property(
"gradient_scale_strategy",
[](const BuildStrategy &self) { return self.gradient_scale_; },
[](BuildStrategy &self,
BuildStrategy::GradientScaleStrategy strategy) {
self.gradient_scale_ = strategy;
})
},
R"DOC(The type is STR, there are three ways of defining :math:`loss@grad` in
ParallelExecutor, 'CoeffNumDevice', 'One' and 'Customized'. By default,
ParallelExecutor sets the :math:`loss@grad` according to the number of devices.
If you want to customize :math:`loss@grad`, you can choose 'Customized'.
Default 'CoeffNumDevice'.)DOC")
.def_property(
"debug_graphviz_path",
[](const BuildStrategy &self) { return self.debug_graphviz_path_; },
[](BuildStrategy &self, const std::string &path) {
self.debug_graphviz_path_ = path;
})
},
R"DOC(The type is STR, debug_graphviz_path indicate the path that
writing the SSA Graph to file in the form of graphviz, you.
It is useful for debugging. Default "")DOC")
.def_property(
"enable_data_balance",
[](const BuildStrategy &self) { return self.enable_data_balance_; },
[](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; })
.def_property("fuse_elewise_add_act_ops",
[](BuildStrategy &self, bool b) {
self.enable_data_balance_ = b;
}) // FIXME(chengudo): enable_data_balance seems not important
.def_property(
"fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
return self.fuse_elewise_add_act_ops_;
},
[](BuildStrategy &self, bool b) {
self.fuse_elewise_add_act_ops_ = b;
});
},
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
to fuse elementwise_add_op and activation_op,
it may make the execution faster. Default False)DOC");
pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &,
......
......@@ -654,11 +654,21 @@ function gen_fluid_inference_lib() {
if [[ ${WITH_C_API:-OFF} == "OFF" && ${WITH_INFERENCE:-ON} == "ON" ]] ; then
cat <<EOF
========================================
Deploying fluid inference library ...
Generating fluid inference library ...
========================================
EOF
cmake .. -DWITH_DISTRIBUTE=OFF
make -j `nproc` inference_lib_dist
fi
}
function tar_fluid_inference_lib() {
if [[ ${WITH_C_API:-OFF} == "OFF" && ${WITH_INFERENCE:-ON} == "ON" ]] ; then
cat <<EOF
========================================
Taring fluid inference library ...
========================================
EOF
cd ${PADDLE_ROOT}/build
cp -r fluid_install_dir fluid
tar -czf fluid.tgz fluid
......@@ -673,7 +683,7 @@ function test_fluid_inference_lib() {
========================================
EOF
cd ${PADDLE_ROOT}/paddle/fluid/inference/api/demo_ci
./run.sh ${PADDLE_ROOT} ${WITH_MKL:-ON} ${WITH_GPU:-OFF}
./run.sh ${PADDLE_ROOT} ${WITH_MKL:-ON} ${WITH_GPU:-OFF} ${INFERENCE_DEMO_INSTALL_DIR} ${TENSORRT_INCLUDE_DIR:-/usr/local/TensorRT/include} ${TENSORRT_LIB_DIR:-/usr/local/TensorRT/lib}
./clean.sh
fi
}
......@@ -722,6 +732,7 @@ function main() {
fluid_inference_lib)
cmake_gen ${PYTHON_ABI:-""}
gen_fluid_inference_lib
tar_fluid_inference_lib
test_fluid_inference_lib
;;
check_style)
......
......@@ -55,7 +55,11 @@ def data(name,
Args:
name(str): The name/alias of the function
shape(list): Tuple declaring the shape.
append_batch_size(bool): Whether or not to append the data as a batch.
append_batch_size(bool):
1. If true, it prepends -1 to the shape.
For example if shape=[1], the resulting shape is [-1, 1].
2. If shape contains -1, such as shape=[1, -1],
append_batch_size will be enforced to be be False (ineffective).
dtype(int|float): The type of data : float32, float_16, int etc
type(VarType): The output type. By default it is LOD_TENSOR.
lod_level(int): The LoD Level. 0 means the input data is not a sequence.
......
......@@ -351,7 +351,6 @@ def dynamic_lstm(input,
c_0(Variable): The initial cell state is an optional input, default is zero.
This is a tensor with shape (N x D), where N is the
batch size. `h_0` and `c_0` can be NULL but only at the same time.
param_attr(ParamAttr|None): The parameter attribute for the learnable
hidden-hidden weights.
......@@ -359,6 +358,11 @@ def dynamic_lstm(input,
W_{fh}, W_{oh}`}
- The shape is (D x 4D), where D is the hidden
size.
If it is set to None or one attribute of ParamAttr,
dynamic_lstm will create ParamAttr as param_attr.
If the Initializer of the param_attr is not set, the
parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The bias attribute for the learnable bias
weights, which contains two parts, input-hidden
bias weights and peephole connections weights if
......@@ -371,6 +375,11 @@ def dynamic_lstm(input,
- Biases = { :math:`b_c, b_i, b_f, b_o, W_{ic}, \
W_{fc}, W_{oc}`}.
- The shape is (1 x 7D).
If it is set to None or one attribute of ParamAttr,
dynamic_lstm will create ParamAttr as bias_attr.
If the Initializer of the bias_attr is not set,
the bias is initialized zero. Default: None.
use_peepholes (bool): ${use_peepholes_comment}
is_reverse (bool): ${is_reverse_comment}
gate_activation (str): ${gate_activation_comment}
......@@ -389,11 +398,11 @@ def dynamic_lstm(input,
hidden_dim = 512
forward_proj = fluid.layers.fc(input=input_seq, size=hidden_dim * 4,
act=None, bias_attr=None)
bias_attr=False)
forward, _ = fluid.layers.dynamic_lstm(
input=forward_proj, size=hidden_dim * 4, use_peepholes=False)
"""
assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp."
helper = LayerHelper('lstm', **locals())
size = size // 4
weight = helper.create_parameter(
......@@ -528,6 +537,11 @@ def dynamic_lstmp(input,
size.
- Projection weight = {:math:`W_{rh}`}.
- The shape of projection weight is (D x P).
If it is set to None or one attribute of ParamAttr,
dynamic_lstm will create ParamAttr as param_attr.
If the Initializer of the param_attr is not set, the
parameter is initialized with Xavier. Default: None.
bias_attr(ParamAttr|None): The bias attribute for the learnable bias
weights, which contains two parts, input-hidden
bias weights and peephole connections weights if
......@@ -540,6 +554,11 @@ def dynamic_lstmp(input,
- Biases = { :math:`b_c, b_i, b_f, b_o, W_{ic}, \
W_{fc}, W_{oc}`}.
- The shape is (1 x 7D).
If it is set to None or one attribute of ParamAttr,
dynamic_lstm will create ParamAttr as bias_attr.
If the Initializer of the bias_attr is not set,
the bias is initialized zero. Default: None.
use_peepholes(bool): Whether to enable diagonal/peephole connections,
default `True`.
is_reverse(bool): Whether to compute reversed LSTM, default `False`.
......@@ -584,6 +603,7 @@ def dynamic_lstmp(input,
proj_activation="tanh")
"""
assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp."
helper = LayerHelper('lstmp', **locals())
size = size // 4
weight = helper.create_parameter(
......@@ -681,8 +701,18 @@ def dynamic_gru(input,
The first part are weights of the update gate and reset gate with
shape :math:`(D \\times 2D)`, and the second part are weights for
candidate hidden state with shape :math:`(D \\times D)`.
bias_attr(ParamAttr): The parameter attribute for learnable the
hidden-hidden bias.
If it is set to None or one attribute of ParamAttr, dynamic_gru will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias
of GRU. Note that the bias with :math:`(1 \\times 3D)` concatenates
the bias in the update gate, reset gate and candidate calculations.
If it is set to False, no bias will be applied to the update gate,
reset gate and candidate calculations. If it is set to None or one
attribute of ParamAttr, dynamic_gru will create ParamAttr as
bias_attr. If the Initializer of the bias_attr is not set, the bias
is initialized zero. Default: None.
is_reverse(bool): Whether to compute reversed GRU, default
:attr:`False`.
gate_activation(str): The activation for update gate and reset gate.
......@@ -781,10 +811,29 @@ def gru_unit(input,
Args:
input (Variable): The fc transformed input value of current step.
hidden (Variable): The hidden value of lstm unit from previous step.
hidden (Variable): The hidden value of gru unit from previous step.
size (integer): The input dimension value.
param_attr (ParamAttr): The weight parameters for gru unit. Default: None
bias_attr (ParamAttr): The bias parameters for gru unit. Default: None
param_attr(ParamAttr|None): The parameter attribute for the learnable
hidden-hidden weight matrix. Note:
- The shape of the weight matrix is :math:`(T \\times 3D)`, where
:math:`D` is the hidden size.
- All elements in the weight matrix can be divided into two parts.
The first part are weights of the update gate and reset gate with
shape :math:`(D \\times 2D)`, and the second part are weights for
candidate hidden state with shape :math:`(D \\times D)`.
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias
of GRU. Note that the bias with :math:`(1 \\times 3D)` concatenates
the bias in the update gate, reset gate and candidate calculations.
If it is set to False, no bias will be applied to the update gate,
reset gate and candidate calculations. If it is set to None or one
attribute of ParamAttr, gru_unit will create ParamAttr as
bias_attr. If the Initializer of the bias_attr is not set, the bias
is initialized zero. Default: None.
activation (string): The activation type for cell (actNode).
Default: 'tanh'
gate_activation (string): The activation type for gates (actGate).
......@@ -1265,7 +1314,8 @@ def sequence_conv(input,
padding=None,
bias_attr=None,
param_attr=None,
act=None):
act=None,
name=None):
"""
This function creates the op for sequence_conv, using the inputs and
other convolutional configurations for the filters and stride as given
......@@ -1277,9 +1327,19 @@ def sequence_conv(input,
filter_size (int): the filter size (H and W).
filter_stride (int): stride of the filter.
padding (bool): if True, add paddings.
bias_attr (ParamAttr|None): attributes for bias
param_attr (ParamAttr|None): attributes for parameter
act (str): the activation type
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of sequence_conv.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, sequence_conv
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of sequence_conv. If it is set to None or one attribute of ParamAttr, sequence_conv
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
act (str): Activation type, if it is set to None, activation is not appended.
Default: None.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None.
Returns:
Variable: output of sequence_conv
......@@ -1308,7 +1368,7 @@ def sequence_conv(input,
return helper.append_activation(pre_act)
def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=False):
def sequence_softmax(input, use_cudnn=False, name=None):
"""
This function computes the softmax activation among all time-steps for each
sequence. The dimension of each time-step should be 1. Thus, the shape of
......@@ -1328,10 +1388,10 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=False):
Args:
input (Variable): The input variable which is a LoDTensor.
bias_attr (ParamAttr|None): attributes for bias
param_attr (ParamAttr|None): attributes for parameter
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn \
library is installed. Default: False
library is installed. Default: False.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None.
Returns:
Variable: output of sequence_softmax
......@@ -1355,7 +1415,7 @@ def sequence_softmax(input, param_attr=None, bias_attr=None, use_cudnn=False):
return softmax_out
def softmax(input, param_attr=None, bias_attr=None, use_cudnn=True, name=None):
def softmax(input, use_cudnn=True, name=None):
"""
The input of the softmax operator is a tensor of any rank. The output tensor
has the same shape as the input.
......@@ -1382,10 +1442,10 @@ def softmax(input, param_attr=None, bias_attr=None, use_cudnn=True, name=None):
Args:
input (Variable): The input variable.
bias_attr (ParamAttr): attributes for bias
param_attr (ParamAttr): attributes for parameter
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn \
library is installed.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None.
Returns:
Variable: output of softmax
......@@ -1491,14 +1551,23 @@ def conv2d(input,
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1
param_attr (ParamAttr): The parameters to the Conv2d Layer. Default: None
bias_attr (ParamAttr): Bias parameter for the Conv2d layer. Default: None
connected to the second half of the input channels. Default: groups=1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act (str): Activation type. Default: None
act (str): Activation type, if it is set to None, activation is not appended.
Default: None
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
will be named automatically. Default: None
Returns:
Variable: The tensor variable storing the convolution and \
......@@ -1516,7 +1585,7 @@ def conv2d(input,
"""
num_channels = input.shape[1]
assert param_attr is not False, "param_attr should not be False here."
l_type = 'conv2d'
if (num_channels == groups and num_filters % num_channels == 0 and
not use_cudnn):
......@@ -1544,7 +1613,8 @@ def conv2d(input,
filter_shape = [num_filters, int(num_filter_channels)] + filter_size
def _get_default_param_initializer():
std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
filter_elem_num = filter_size[0] * filter_size[1] * num_channels
std = (2.0 / filter_elem_num)**0.5
return Normal(0.0, std, 0)
filter_param = helper.create_parameter(
......@@ -1655,13 +1725,22 @@ def conv3d(input,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1
param_attr (ParamAttr): The parameters to the Conv3d Layer. Default: None
bias_attr (ParamAttr): Bias parameter for the Conv3d layer. Default: None
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of conv3d. If it is set to None or one attribute of ParamAttr, conv3d
will create ParamAttr as param_attr. If it is set to None, the parameter
is initialized with :math:`Normal(0.0, std)`, and the :math:`std` is
:math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv3d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv3d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act (str): Activation type. Default: None
act (str): Activation type, if it is set to None, activation is not appended.
Default: None.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
will be named automatically. Default: None.
Returns:
Variable: The tensor variable storing the convolution and \
......@@ -1679,7 +1758,7 @@ def conv3d(input,
"""
l_type = 'conv3d'
assert param_attr is not False, "param_attr should not be False here."
helper = LayerHelper(l_type, **locals())
dtype = helper.input_dtype()
......@@ -1704,7 +1783,9 @@ def conv3d(input,
filter_shape = [num_filters, num_filter_channels] + filter_size
def _get_default_param_initializer():
std = (2.0 / (filter_size[0]**3 * num_channels))**0.5
filter_elem_num = filter_size[0] * filter_size[1] * filter_size[
2] * num_channels
std = (2.0 / filter_elem_num)**0.5
return Normal(0.0, std, 0)
filter_param = helper.create_parameter(
......@@ -2106,8 +2187,14 @@ def batch_norm(input,
is_test(bool, Default False): Used for training or training.
momentum(float, Default 0.9):
epsilon(float, Default 1e-05):
param_attr(ParamAttr): The parameter attribute for Parameter `scale`.
bias_attr(ParamAttr): The parameter attribute for Parameter `bias`.
param_attr(ParamAttr|None): The parameter attribute for Parameter `scale`
of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr(ParamAttr|None): The parameter attribute for the bias of batch_norm.
If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
data_layout(string, default NCHW): NCHW|NHWC
in_place(bool, Default False): Make the input and output of batch norm reuse memory.
name(string, Default None): A name for this layer(optional). If set None, the layer
......@@ -2127,6 +2214,7 @@ def batch_norm(input,
hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
hidden2 = fluid.layers.batch_norm(input=hidden1)
"""
assert bias_attr is not False, "bias_attr should not be False in batch_norm."
helper = LayerHelper('batch_norm', **locals())
dtype = helper.input_dtype()
......@@ -2243,19 +2331,28 @@ def layer_norm(input,
Args:
input(Variable): The input tensor variable.
scale(bool): Whether to learn the adaptive gain :math:`g` after
normalization.
normalization. Default True.
shift(bool): Whether to learn the adaptive bias :math:`b` after
normalization.
begin_norm_axis(bool): The normalization will be performed along
normalization. Default True.
begin_norm_axis(int): The normalization will be performed along
dimensions from :attr:`begin_norm_axis` to :attr:`rank(input)`.
Default 1.
epsilon(float): The small value added to the variance to prevent
division by zero.
division by zero. Default 1e-05.
param_attr(ParamAttr|None): The parameter attribute for the learnable
gain :math:`g`.
gain :math:`g`. If :attr:`scale` is False, :attr:`param_attr` is
omitted. If :attr:`scale` is True and :attr:`param_attr` is None,
a default :code:`ParamAttr` would be added as scale. The
:attr:`param_attr` is initialized as 1 if it is added. Default None.
bias_attr(ParamAttr|None): The parameter attribute for the learnable
bias :math:`b`.
bias :math:`b`. If :attr:`shift` is False, :attr:`bias_attr` is
omitted. If :attr:`shift` is True and :attr:`param_attr` is None,
a default :code:`ParamAttr` would be added as bias. The
:attr:`bias_attr` is initialized as 0 if it is added. Default None.
act(str): Activation to be applied to the output of layer normalizaiton.
name (str): The name of this layer. It is optional.
Default None.
name(str): The name of this layer. It is optional. Default None, and a
unique name would be generated automatically.
Returns:
${y_comment}
......@@ -2396,15 +2493,22 @@ def conv2d_transpose(input,
when group=2, the first half of the filters is only connected to the
first half of the input channels, while the second half of the
filters is only connected to the second half of the input channels.
Default: groups=1
param_attr(ParamAttr): The parameters to the Conv2d_transpose Layer.
Default: None
bias_attr(ParamAttr): Bias parameter for the Conv2d layer. Default: None
Default: groups = 1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d_transpose.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d_transpose
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act(str): Activation type. Default: None
library is installed. Default: True.
act (str): Activation type, if it is set to None, activation is not appended.
Default: None.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
will be named automatically. Default: True.
Returns:
Variable: The tensor variable storing the convolution transpose result.
......@@ -2419,7 +2523,7 @@ def conv2d_transpose(input,
data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
conv2d_transpose = fluid.layers.conv2d_transpose(input=data, num_filters=2, filter_size=3)
"""
assert param_attr is not False, "param_attr should not be False in conv2d_transpose."
input_channel = input.shape[1]
op_type = 'conv2d_transpose'
......@@ -2455,6 +2559,7 @@ def conv2d_transpose(input,
else:
filter_size = utils.convert_to_list(filter_size, 2,
'conv2d_transpose.filter_size')
if output_size is None:
output_size = []
elif isinstance(output_size, list) or isinstance(output_size, int):
......@@ -2464,6 +2569,7 @@ def conv2d_transpose(input,
padding = utils.convert_to_list(padding, 2, 'padding')
groups = 1 if groups is None else groups
filter_shape = [input_channel, num_filters // groups] + filter_size
img_filter = helper.create_parameter(
dtype=input.dtype, shape=filter_shape, attr=helper.param_attr)
......@@ -2576,12 +2682,19 @@ def conv3d_transpose(input,
first half of the input channels, while the second half of the
filters is only connected to the second half of the input channels.
Default: groups=1
param_attr(ParamAttr): The parameters to the Conv3d_transpose Layer.
Default: None
bias_attr(ParamAttr): Bias parameter for the Conv3d layer. Default: None
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of conv3d_transpose. If it is set to None or one attribute of ParamAttr, conv3d_transpose
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv3d_transpose.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv3d_transpose
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act(str): Activation type. Default: None
act (str): Activation type, if it is set to None, activation is not appended.
Default: None.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
......@@ -2598,6 +2711,7 @@ def conv3d_transpose(input,
data = fluid.layers.data(name='data', shape=[3, 12, 32, 32], dtype='float32')
conv3d_transpose = fluid.layers.conv3d_transpose(input=data, num_filters=2, filter_size=3)
"""
assert param_attr is not False, "param_attr should not be False in conv3d_transpose."
l_type = "conv3d_transpose"
helper = LayerHelper(l_type, **locals())
if not isinstance(input, Variable):
......@@ -3054,10 +3168,18 @@ def lstm_unit(x_t,
cell_t_prev (Variable): The cell value of lstm unit, a 2-D tensor with
shape M x S, M for batch size and S for size of lstm unit.
forget_bias (float): The forget bias of lstm unit.
param_attr (ParamAttr): The attributes of parameter weights, used to set
initializer, name etc.
bias_attr (ParamAttr): The attributes of bias weights, if not False,
bias weights will be created and be set to default value.
param_attr(ParamAttr|None): The parameter attribute for the learnable
hidden-hidden weights.
If it is set to None or one attribute of ParamAttr,
lstm_unit will create ParamAttr as param_attr.
If the Initializer of the param_attr is not set, the
parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The bias attribute for the learnable bias
weights. If it is set to False, no bias will be added
to the output units. If it is set to None or one attribute of ParamAttr,
lstm_unit will create ParamAttr as bias_attr.
If the Initializer of the bias_attr is not set,
the bias is initialized zero. Default: None.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
......@@ -3971,7 +4093,8 @@ def nce(input,
sample_weight=None,
param_attr=None,
bias_attr=None,
num_neg_samples=None):
num_neg_samples=None,
name=None):
"""
${comment}
......@@ -3982,9 +4105,18 @@ def nce(input,
sample_weight (Variable|None): A Variable of shape [batch_size, 1]
storing a weight for each sample. The default weight for each
sample is 1.0.
param_attr (ParamAttr|None): attributes for parameter
bias_attr (ParamAttr|None): attributes for bias
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of nce. If it is set to None or one attribute of ParamAttr, nce
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of nce.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, nce
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
num_neg_samples (int): ${num_neg_samples_comment}
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None.
Returns:
Variable: The output nce loss.
......@@ -4017,19 +4149,28 @@ def nce(input,
"""
helper = LayerHelper('nce', **locals())
assert isinstance(input, Variable)
dim = input.shape[1]
assert isinstance(label, Variable)
dim = input.shape[1]
num_true_class = label.shape[1]
w = helper.create_parameter(
attr=helper.param_attr,
shape=[num_total_classes, dim],
is_bias=False,
dtype=input.dtype)
inputs = {
'Input': input,
'Label': label,
'Weight': w,
'SampleWeight': sample_weight if sample_weight is not None else []
}
if helper.bias_attr:
b = helper.create_parameter(
attr=helper.bias_attr,
shape=[num_total_classes, 1],
is_bias=True,
dtype=input.dtype)
inputs['Bias'] = b
cost = helper.create_tmp_variable(dtype=input.dtype)
sample_logits = helper.create_tmp_variable(dtype=input.dtype)
sample_labels = helper.create_tmp_variable(dtype=label.dtype)
......@@ -4046,13 +4187,7 @@ def nce(input,
helper.append_op(
type='nce',
inputs={
'Input': input,
'Label': label,
'Weight': w,
'Bias': b,
'SampleWeight': sample_weight if sample_weight is not None else []
},
inputs=inputs,
outputs={
'Cost': cost,
'SampleLogits': sample_logits,
......@@ -4062,7 +4197,12 @@ def nce(input,
return cost / (num_neg_samples + 1)
def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None):
def hsigmoid(input,
label,
num_classes,
param_attr=None,
bias_attr=None,
name=None):
"""
The hierarchical sigmoid operator is used to accelerate the training
process of language model. This operator organizes the classes into a
......@@ -4083,11 +4223,17 @@ def hsigmoid(input, label, num_classes, param_attr=None, bias_attr=None):
label (Variable): The tensor variable contains labels of training data.
It's a tensor with shape is :math:`[N \\times 1]`.
num_classes: (int), The number of classes, must not be less than 2.
param_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for learnable parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter
attribute for the bias of this layer. If it is set to False, no
bias will be applied.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of hsigmoid. If it is set to None or one attribute of ParamAttr, hsigmoid
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of hsigmoid.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, hsigmoid
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None.
Returns:
Out: (Tensor) The cost of hierarchical sigmoid operator. the shape is [N, 1]
......
......@@ -14,6 +14,8 @@
from __future__ import print_function
from .layer_function_generator import generate_layer_fn, generate_layer_fn_noattr
from .. import core
from ..framework import convert_np_dtype_to_dtype_
__activations_noattr__ = [
'sigmoid',
......@@ -58,8 +60,11 @@ _uniform_random_ = generate_layer_fn('uniform_random')
def uniform_random(shape, dtype=None, min=None, max=None, seed=None):
locals_var = locals().keys()
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
kwargs = dict()
for name in locals():
for name in locals_var:
val = locals()[name]
if val is not None:
kwargs[name] = val
......@@ -78,8 +83,9 @@ _hard_shrink_ = generate_layer_fn('hard_shrink')
def hard_shrink(x, threshold=None):
locals_var = locals().keys()
kwargs = dict()
for name in locals():
for name in locals_var:
val = locals()[name]
if val is not None:
kwargs[name] = val
......@@ -99,12 +105,12 @@ _cum_sum_ = generate_layer_fn('cumsum')
def cumsum(x, axis=None, exclusive=None, reverse=None):
locals_var = locals().keys()
kwargs = dict()
for name in locals():
for name in locals_var:
val = locals()[name]
if val is not None:
kwargs[name] = val
return _cum_sum_(**kwargs)
......@@ -121,8 +127,9 @@ _thresholded_relu_ = generate_layer_fn('thresholded_relu')
def thresholded_relu(x, threshold=None):
locals_var = locals().keys()
kwargs = dict()
for name in locals():
for name in locals_var:
val = locals()[name]
if val is not None:
kwargs[name] = val
......
......@@ -111,7 +111,7 @@ def create_global_var(shape,
force_cpu=False,
name=None):
"""
Create a new variable in the global block(block 0).
Create a new tensor variable with value in the global block(block 0).
Args:
shape(list[int]): shape of the variable
......
......@@ -64,23 +64,33 @@ def simple_img_conv_pool(input,
average-pooling. Default :math:`max`.
global_pooling (bool): Whether to use the global pooling. If global_pooling = true,
pool_size and pool_padding while be ignored. Default False
conv_stride (int|list|tuple): The stride size of the Conv2d Layer. If stride is a
conv_stride (int|list|tuple): The stride size of the conv2d Layer. If stride is a
list or tuple, it must contain two integers, (conv_stride_H, conv_stride_W). Otherwise,
the conv_stride_H = conv_stride_W = conv_stride. Default: conv_stride = 1.
conv_padding (int|list|tuple): The padding size of the Conv2d Layer. If padding is
conv_padding (int|list|tuple): The padding size of the conv2d Layer. If padding is
a list or tuple, it must contain two integers, (conv_padding_H, conv_padding_W).
Otherwise, the conv_padding_H = conv_padding_W = conv_padding. Default: conv_padding = 0.
conv_dilation (int|list|tuple): The dilation size of the Conv2d Layer. If dilation is
conv_dilation (int|list|tuple): The dilation size of the conv2d Layer. If dilation is
a list or tuple, it must contain two integers, (conv_dilation_H, conv_dilation_W).
Otherwise, the conv_dilation_H = conv_dilation_W = conv_dilation. Default: conv_dilation = 1.
conv_groups (int): The groups number of the Conv2d Layer. According to grouped
conv_groups (int): The groups number of the conv2d Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1
param_attr (ParamAttr): The parameters to the Conv2d Layer. Default: None
bias_attr (ParamAttr): Bias parameter for the Conv2d layer. Default: None
act (str): Activation type for Conv2d. Default: None
connected to the second half of the input channels. Default: groups=1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`.
Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
act (str): Activation type for conv2d, if it is set to None, activation is not
appended. Default: None.
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
......
......@@ -659,6 +659,9 @@ class AdamaxOptimizer(Optimizer):
optimizer = fluid.optimizer.Adamax(learning_rate=0.2)
optimizer.minimize(cost)
Notes:
Currently, AdamaxOptimizer doesn't support sparse gradient.
"""
_moment_acc_str = "moment"
_inf_norm_acc_str = "inf_norm"
......@@ -778,6 +781,9 @@ class DecayedAdagradOptimizer(Optimizer):
optimizer = fluid.optimizer.DecayedAdagrad(learning_rate=0.2)
optimizer.minimize(cost)
Notes:
Currently, DecayedAdagradOptimizer doesn't support sparse gradient.
"""
_moment_acc_str = "moment"
......@@ -858,6 +864,9 @@ class AdadeltaOptimizer(Optimizer):
optimizer = fluid.optimizer.Adadelta(
learning_rate=0.0003, epsilon=1.0e-6, rho=0.95)
_, params_grads = optimizer.minimize(cost)
Notes:
Currently, AdadeltaOptimizer doesn't support sparse gradient.
"""
_avg_squared_grad_acc_str = "_avg_squared_grad"
......@@ -1126,6 +1135,9 @@ class FtrlOptimizer(Optimizer):
optimizer = fluid.optimizer.Ftrl(0.0001)
_, params_grads = optimizer.minimize(cost)
Notes:
Currently, FtrlOptimizer doesn't support sparse gradient.
"""
_squared_acc_str = "squared"
......
......@@ -31,15 +31,32 @@ BuildStrategy = core.ParallelExecutor.BuildStrategy
class ParallelExecutor(object):
"""
ParallelExecutor can run program in parallel.
ParallelExecutor is designed for data parallelism, which focuses on distributing
the data across different nodes and every node operates on the data in parallel.
If you use ParallelExecutor to run the current program on GPU, the node means GPU
device, and ParallelExecutor will get the available GPU device automatically on
the current machine. If you use ParallelExecutor to run the current program on CPU,
the node means the CPU device, and you can specify the CPU device number by adding
'CPU_NUM' environment variable, for example 'CPU_NUM=4', if the environment variable
is not found, ParallelExecutor will call `multiprocessing.cpu_count` to get the number
of CPUs in the system.
Args:
use_cuda (bool): Whether to use CUDA or not.
loss_name (str): The loss name must set in training. Default None.
main_program (Program): The program that need to run, if not provided,
then default_main_program will be used. Default None.
share_vars_from(ParallelExecutor): If provied, it will share variables
share_vars_from(ParallelExecutor): If provide, it will share variables
from the specified ParallelExecutor. Default None.
exec_strategy(ExecutionStrategy): exec_strategy is used to control how to run
the program in ParallelExecutor, for example how many threads are used to
execute the program, how many iterations to clean up the temp variables
which is generated during execution. For more information, please refer
to fluid.ExecutionStrategy. Default None.
build_strategy(BuildStrategy): build_strategy is used to control how to
build the SSA Graph in ParallelExecutor by setting the property,
for example reduce_strategy, gradient_scale_strategy. For more information,
please refer to fluid.BuildStrategy. Default None.
num_trainers(int): If greater than 1, NCCL will be initialized with
multiple rank of nodes, each node should have same number of GPUs.
Distributed training will be enabled then. Default 1.
......
......@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest
......@@ -88,5 +90,97 @@ class TestMomentumOp2(OpTest):
self.check_output()
class TestSparseMomentumOp(unittest.TestCase):
def setUp(self):
self.use_nesterov = False
def check_with_place(self, place):
self.init_kernel()
scope = core.Scope()
# create and initialize Grad Variable
height = 10
rows = [0, 4, 7]
row_numel = 12
mu = 1.0
use_nesterov = self.use_nesterov
# create and initialize Param Variable
param = scope.var('Param').get_tensor()
param_array = np.full((height, row_numel), 5.0).astype("float32")
param.set(param_array, place)
param_out = scope.var("ParamOut").get_tensor()
param_out_array = np.full((height, row_numel), 0.0).astype("float32")
param_out.set(param_out_array, place)
grad_selected_rows = scope.var('Grad').get_selected_rows()
grad_selected_rows.set_height(height)
grad_selected_rows.set_rows(rows)
grad_np_array = np.ones((len(rows), row_numel)).astype("float32")
grad_np_array[0, 0] = 2.0
grad_np_array[2, 8] = 4.0
grad_tensor = grad_selected_rows.get_tensor()
grad_tensor.set(grad_np_array, place)
velocity = scope.var('Velocity').get_tensor()
velocity_np_array = np.ones((height, row_numel)).astype("float32")
velocity.set(velocity_np_array, place)
velocity_out = scope.var('VelocityOut').get_tensor()
velocity_out_np_array = np.full((height, row_numel),
0.0).astype("float32")
velocity_out.set(velocity_out_np_array, place)
# create and initialize LeraningRate Variable
lr = scope.var('LearningRate').get_tensor()
lr_array = np.full((1), 2.0).astype("float32")
lr.set(lr_array, place)
# create and run operator
op = Operator(
"momentum",
Param='Param',
Grad='Grad',
Velocity='Velocity',
ParamOut='ParamOut',
VelocityOut='VelocityOut',
LearningRate='LearningRate',
mu=mu,
use_nesterov=use_nesterov)
op.run(scope, place)
# get and compare result
param_out_np_array = np.array(param_out)
velocity_out_np_array = np.array(velocity_out)
# TODO(dzh): add a more suitable general numpy interface
# for sparse update.
_grad_np_array = np.full((height, row_numel), 0.0).astype("float32")
for i in range(len(rows)):
_grad_np_array[rows[i]] = grad_np_array[i]
_velocity_out = mu * velocity_np_array + _grad_np_array
_param = param_array
if use_nesterov:
_param_out = _param - (_grad_np_array + _velocity_out * mu
) * lr_array
else:
_param_out = _param - lr_array * _velocity_out
self.assertTrue((_velocity_out == velocity_out_np_array).all())
self.assertTrue((_param_out == param_out_np_array).all())
def init_kernel(self):
pass
def test_sparse_momentum(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
class TestSparseMomentumOp2(TestSparseMomentumOp):
def init_kernel(self):
self.use_nesterov = True
if __name__ == "__main__":
unittest.main()
......@@ -19,33 +19,76 @@ import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
def create_selected_rows_and_tensor(scope, place, height, row_num,
embedding_size):
sr = scope.var("@selected_rows@").get_selected_rows()
tensor = scope.var("grad").get_tensor()
rows = np.random.random_integers(
low=0, high=height - 1, size=[row_num, ]).astype('int64')
sr_val = np.random.random(size=[row_num, embedding_size]).astype('float32')
sr.set_height(height)
sr.set_rows(rows)
sr.get_tensor().set(sr_val, place)
tensor_val = np.zeros(shape=[height, embedding_size], dtype='float32')
for i in range(row_num):
row = rows[i]
tensor_val[row, :] = tensor_val[row, :] + sr_val[i, :]
tensor.set(tensor_val, place)
return tensor_val, sr_val
class TestBase(unittest.TestCase):
def setup(self, centered, epsilon=1e-6):
def setup(self,
place,
is_sparse,
centered,
size,
row_num=None,
epsilon=1e-6):
np.random.seed(5) # fix seed
self.scope = fluid.global_scope()
self.place = place
self.param_name = "param"
self.param = np.random.random((123, 321)).astype("float32")
self.param = np.random.random(size).astype("float32")
self.mean_square_name = "mean_square"
self.mean_square = np.random.random((123, 321)).astype("float32")
self.mean_square = np.random.uniform(
low=1, high=2, size=size).astype("float32")
self.mean_grad_name = "mean_grad"
self.mean_grad = np.random.random((123, 321)).astype("float32")
self.mean_grad = np.random.random(size).astype("float32")
self.lr_name = "lr"
self.learning_rate = np.array([0.01]).astype("float32")
self.grad_name = "grad"
self.grad = np.random.random((123, 321)).astype("float32")
self.is_sparse = is_sparse
if self.is_sparse:
self.grad_sr_name = "@selected_rows@"
self.grad, self.grad_sr = create_selected_rows_and_tensor(
self.scope, place, size[0], row_num, size[1])
else:
self.grad = np.random.random(size).astype("float32")
grad_tensor = self.scope.var(self.grad_name).get_tensor()
grad_tensor.set(self.grad, place)
self.moment_name = "moment"
self.moment = np.zeros((123, 321)).astype("float32")
self.moment = np.random.uniform(
low=0, high=1, size=size).astype("float32")
self.epsilon = epsilon
self.decay = 0.9
self.momentum = 0.0
self.momentum = 0.1
self.centered = centered
self.ms_out = self.decay * self.mean_square + (1 - self.decay
......@@ -61,118 +104,122 @@ class TestBase(unittest.TestCase):
self.param_out = self.param - self.moment_out
def check(self,
actual_t,
expect_t,
place,
out_name,
atol=1e-5,
equal_nan=False):
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol, equal_nan=equal_nan),
"Output (" + out_name + ") has diff at " + str(place) + "\nExpect "
+ str(expect_t) + "\n" + "But Got" + str(actual_t))
class TestRmspropOp(TestBase):
def check_with_place(self, place, centered, epsilon):
self.setup(centered, epsilon)
scope = core.Scope()
# create and initialize Param Variable
param = scope.var(self.param_name).get_tensor()
param.set(self.param, place)
self.param_tensor = self.scope.var(self.param_name).get_tensor()
self.param_tensor.set(self.param, place)
mean_square = scope.var(self.mean_square_name).get_tensor()
mean_square.set(self.mean_square, place)
self.mean_square_tensor = self.scope.var(
self.mean_square_name).get_tensor()
self.mean_square_tensor.set(self.mean_square, place)
lr = scope.var(self.lr_name).get_tensor()
lr = self.scope.var(self.lr_name).get_tensor()
lr.set(self.learning_rate, place)
grad = scope.var(self.grad_name).get_tensor()
grad.set(self.grad, place)
self.moment_tensor = self.scope.var(self.moment_name).get_tensor()
self.moment_tensor.set(self.moment, place)
moment = scope.var(self.moment_name).get_tensor()
moment.set(self.moment, place)
if self.centered:
self.mean_grad_tensor = self.scope.var(
self.mean_grad_name).get_tensor()
self.mean_grad_tensor.set(self.mean_grad, place)
def check(self, actual_t, expect_t, place, out_name, atol=1e-5):
self.assertTrue(
np.allclose(
actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place) + "\nExpect "
+ str(expect_t) + "\n" + "But Got" + str(actual_t))
# create and run sgd operator
if self.centered:
mean_grad = scope.var(self.mean_grad_name).get_tensor()
mean_grad.set(self.mean_grad, place)
rmsprop_op = Operator(
"rmsprop",
Param=self.param_name,
Grad=self.grad_name,
MeanSquare=self.mean_square_name,
MeanGrad=self.mean_grad_name,
Moment=self.moment_name,
LearningRate=self.lr_name,
ParamOut=self.param_name,
MeanSquareOut=self.mean_square_name,
MomentOut=self.moment_name,
MeanGradOut=self.mean_grad_name,
epsilon=self.epsilon,
decay=self.decay,
momentum=self.momentum,
centered=True)
else:
rmsprop_op = Operator(
"rmsprop",
Param=self.param_name,
Grad=self.grad_name,
MeanSquare=self.mean_square_name,
Moment=self.moment_name,
LearningRate=self.lr_name,
ParamOut=self.param_name,
MeanSquareOut=self.mean_square_name,
MomentOut=self.moment_name,
epsilon=self.epsilon,
decay=self.decay,
momentum=self.momentum,
centered=False)
rmsprop_op.run(scope, place)
atol = 1e-5
equal_nan = False
class TestRmspropOp(TestBase):
def check_with_place(self,
place,
is_sparse,
centered,
size,
row_num=None,
epsilon=1e-6):
self.setup(place, is_sparse, centered, size, row_num, epsilon)
self.run_and_check()
def run_and_check(self):
grad_name = self.grad_sr_name if self.is_sparse else self.grad_name
kwargs = {
'Param': self.param_name,
'Grad': grad_name,
'MeanSquare': self.mean_square_name,
'Moment': self.moment_name,
'LearningRate': self.lr_name,
'ParamOut': self.param_name,
'MeanSquareOut': self.mean_square_name,
'MomentOut': self.moment_name,
'epsilon': self.epsilon,
'decay': self.decay,
'momentum': self.momentum,
'centered': self.centered
}
if self.centered:
atol = 1e-3
equal_nan = True
kwargs['MeanGrad'] = self.mean_grad_name
kwargs['MeanGradOut'] = self.mean_grad_name
rmsprop_op = Operator('rmsprop', **kwargs)
atol = 1e-6
rmsprop_op.run(self.scope, self.place)
self.check(
np.array(mean_square), self.ms_out, place, self.mean_square_name)
np.array(self.mean_square_tensor),
self.ms_out,
self.place,
self.mean_square_name,
atol=atol)
self.check(
np.array(moment),
np.array(self.moment_tensor),
self.moment_out,
place,
self.place,
self.moment_name,
atol=atol,
equal_nan=equal_nan)
atol=atol)
self.check(
np.array(param),
np.array(self.param_tensor),
self.param_out,
place,
self.place,
self.param_name,
atol=atol,
equal_nan=equal_nan)
atol=atol)
if self.centered:
self.check(
np.array(mean_grad), self.mg_out, place, self.mean_grad_name)
np.array(self.mean_grad_tensor), self.mg_out, self.place,
self.mean_grad_name)
def test_rmsprop(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
size = (128, 320)
for place in places:
self.check_with_place(place, False, 1e-6)
self.check_with_place(place, False, 1e-10)
self.check_with_place(place, True, 1e-6)
self.check_with_place(place, True, 1e-10)
for centered in [False, True]:
with fluid.scope_guard(core.Scope()):
self.check_with_place(
place, is_sparse=False, centered=centered, size=size)
with fluid.scope_guard(core.Scope()):
self.check_with_place(
place,
is_sparse=True,
centered=centered,
row_num=512,
size=size)
with fluid.scope_guard(core.Scope()):
self.check_with_place(
place,
is_sparse=True,
centered=centered,
row_num=60,
size=size)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册