提交 b139b687 编写于 作者: T tensor-tang

Merge remote-tracking branch 'ups/develop' into fix/jit/exp

test=develop
......@@ -127,6 +127,9 @@ set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
set(FLUID_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_install_dir" CACHE STRING
"A path setting fluid shared and static libraries")
set(FLUID_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/fluid_inference_install_dir" CACHE STRING
"A path setting fluid inference shared and static libraries")
if (WITH_C_API AND WITH_PYTHON)
message(WARNING "It is suggest not embedded a python interpreter in Paddle "
"when using C-API. It will give an unpredictable behavior when using a "
......
......@@ -19,7 +19,7 @@ Our vision is to enable deep learning for everyone via PaddlePaddle.
Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddle/releases) to track the latest feature of PaddlePaddle.
### Latest PaddlePaddle Release: [Fluid 0.15.0](https://github.com/PaddlePaddle/Paddle/tree/v0.15.0)
### Latest PaddlePaddle Release: [Fluid 1.0.0](https://github.com/PaddlePaddle/Paddle/tree/release/1.0.0)
### Install Latest Stable Release:
```
# Linux CPU
......@@ -76,26 +76,26 @@ pip install paddlepaddle-gpu==0.15.0.post85
## Installation
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/install/install_doc.html) on our website.
It is recommended to read [this doc](http://paddlepaddle.org/documentation/docs/zh/1.0/beginners_guide/index.html) on our website.
## Documentation
We provide [English](http://paddlepaddle.org/documentation/docs/en/0.15.0/getstarted/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/beginners_guide/index.html) documentation.
We provide [English](http://paddlepaddle.org/documentation/docs/en/1.0.0/getstarted/index_en.html) and
[Chinese](http://paddlepaddle.org/documentation/docs/zh/1.0/beginners_guide/index.html) documentation.
- [Deep Learning 101](https://github.com/PaddlePaddle/book)
You might want to start from this online interactive book that can run in a Jupyter Notebook.
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/user_guides/howto/training/cluster_howto.html)
- [Distributed Training](http://paddlepaddle.org/documentation/docs/zh/1.0/user_guides/howto/training/cluster_howto.html)
You can run distributed training jobs on MPI clusters.
- [Python API](http://paddlepaddle.org/documentation/api/zh/0.15.0/fluid.html)
- [Python API](http://paddlepaddle.org/documentation/api/zh/1.0/fluid.html)
Our new API enables much shorter programs.
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/0.15.0/new_docs/advanced_usage/development/contribute_to_paddle.html)
- [How to Contribute](http://paddlepaddle.org/documentation/docs/zh/1.0/advanced_usage/development/contribute_to_paddle.html)
We appreciate your contributions!
......
文件模式从 100644 更改为 100755
......@@ -150,16 +150,16 @@ if (WITH_ANAKIN AND WITH_MKL)
SRCS
${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/libinference_anakin_api* # compiled anakin api
${ANAKIN_INSTALL_DIR} # anakin release
DSTS ${dst_dir}/inference/anakin ${FLUID_INSTALL_DIR}/third_party/install/anakin)
DSTS ${FLUID_INSTALL_DIR}/third_party/install/anakin ${FLUID_INSTALL_DIR}/third_party/install/anakin)
list(APPEND inference_deps anakin_inference_lib)
endif()
set(module "inference")
copy(inference_lib DEPS ${inference_deps}
SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*
${src_dir}/${module}/api/paddle_inference_api.h ${src_dir}/${module}/api/demo_ci
${src_dir}/${module}/api/paddle_inference_api.h
${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h
DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
)
set(module "platform")
......@@ -188,18 +188,38 @@ copy(cmake_cache
# This command generates a complete fluid library for both train and inference
add_custom_target(fluid_lib_dist DEPENDS ${fluid_lib_dist_dep})
# Following commands generate a inference-only fluid library
# third_party, version.txt and CMakeCache.txt are the same position with ${FLUID_INSTALL_DIR}
copy(third_party DEPS fluid_lib_dist
SRCS ${FLUID_INSTALL_DIR}/third_party ${FLUID_INSTALL_DIR}/CMakeCache.txt
DSTS ${FLUID_INFERENCE_INSTALL_DIR} ${FLUID_INFERENCE_INSTALL_DIR}
)
# only need libpaddle_fluid.so/a and paddle_inference_api.h for inference-only library
copy(inference_api_lib DEPS fluid_lib_dist
SRCS ${FLUID_INSTALL_DIR}/paddle/fluid/inference/libpaddle_fluid.*
${FLUID_INSTALL_DIR}/paddle/fluid/inference/paddle_inference_api.h
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include
)
add_custom_target(inference_lib_dist DEPENDS third_party inference_api_lib)
# paddle fluid version
execute_process(
COMMAND ${GIT_EXECUTABLE} log --pretty=format:%H -1
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}
OUTPUT_VARIABLE PADDLE_GIT_COMMIT)
set(version_file ${FLUID_INSTALL_DIR}/version.txt)
file(WRITE ${version_file}
"GIT COMMIT ID: ${PADDLE_GIT_COMMIT}\n"
"WITH_MKL: ${WITH_MKL}\n"
"WITH_GPU: ${WITH_GPU}\n")
if(WITH_GPU)
file(APPEND ${version_file}
"CUDA version: ${CUDA_VERSION}\n"
"CUDNN version: v${CUDNN_MAJOR_VERSION}\n")
endif()
function(version version_file)
execute_process(
COMMAND ${GIT_EXECUTABLE} log --pretty=format:%H -1
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}
OUTPUT_VARIABLE PADDLE_GIT_COMMIT)
file(WRITE ${version_file}
"GIT COMMIT ID: ${PADDLE_GIT_COMMIT}\n"
"WITH_MKL: ${WITH_MKL}\n"
"WITH_MKLDNN: ${WITH_MKLDNN}\n"
"WITH_GPU: ${WITH_GPU}\n")
if(WITH_GPU)
file(APPEND ${version_file}
"CUDA version: ${CUDA_VERSION}\n"
"CUDNN version: v${CUDNN_MAJOR_VERSION}\n")
endif()
endfunction()
version(${FLUID_INSTALL_DIR}/version.txt)
version(${FLUID_INFERENCE_INSTALL_DIR}/version.txt)
......@@ -173,6 +173,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
......
......@@ -100,16 +100,6 @@ class OpDesc {
std::vector<std::string> InputNames() const { return MapKeys(inputs_); }
std::vector<std::string> OutputNames() const { return MapKeys(outputs_); }
void SetInputMap(const VariableNameMap &input) {
this->inputs_ = input;
this->need_update_ = true;
}
void SetOutputMap(const VariableNameMap &output) {
this->outputs_ = output;
this->need_update_ = true;
}
const VariableNameMap &Inputs() const { return inputs_; }
const VariableNameMap &Outputs() const { return outputs_; }
......
......@@ -51,9 +51,7 @@ void TestWord2vecPrediction(const std::string& model_path) {
config.model_dir = model_path;
config.use_gpu = false;
config.device = 0;
auto predictor =
::paddle::CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
config);
auto predictor = ::paddle::CreatePaddlePredictor<NativeConfig>(config);
// One single batch
......
......@@ -27,9 +27,7 @@ TEST(AnalysisPredictor, ZeroCopy) {
config.model_dir = FLAGS_dirname + "/word2vec.inference.model";
config.use_feed_fetch_ops = false;
auto predictor =
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
config);
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
auto w0 = predictor->GetInputTensor("firstw");
auto w1 = predictor->GetInputTensor("secondw");
......
......@@ -41,11 +41,8 @@ void CompareTensorRTWithFluid(bool enable_tensorrt) {
config1.device = 0;
config1.max_batch_size = 10;
auto predictor0 =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config0);
auto predictor1 =
CreatePaddlePredictor<MixedRTConfig,
PaddleEngineKind::kAutoMixedTensorRT>(config1);
auto predictor0 = CreatePaddlePredictor<NativeConfig>(config0);
auto predictor1 = CreatePaddlePredictor<MixedRTConfig>(config1);
for (int batch_id = 0; batch_id < 1; batch_id++) {
//# 2. Prepare input.
......
......@@ -77,7 +77,7 @@ endif(NOT WIN32)
link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
link_directories("${PADDLE_LIB}/third_party/install/glog/lib")
link_directories("${PADDLE_LIB}/third_party/install/gflags/lib")
link_directories("${PADDLE_LIB}/paddle/fluid/inference")
link_directories("${PADDLE_LIB}/paddle/lib")
add_executable(${DEMO_NAME} ${DEMO_NAME}.cc)
......@@ -97,10 +97,10 @@ endif()
# Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a
if(WITH_STATIC_LIB)
set(DEPS
${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_STATIC_LIBRARY_SUFFIX})
else()
set(DEPS
${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
${PADDLE_LIB}/paddle/lib/libpaddle_fluid${CMAKE_SHARED_LIBRARY_SUFFIX})
endif()
if (NOT WIN32)
......
......@@ -5,12 +5,13 @@ TEST_GPU_CPU=$3 # test both GPU/CPU mode or only CPU mode
DATA_DIR=$4 # dataset
TENSORRT_INCLUDE_DIR=$5 # TensorRT header file dir, defalut to /usr/local/TensorRT/include
TENSORRT_LIB_DIR=$6 # TensorRT lib file dir, default to /usr/local/TensorRT/lib
inference_install_dir=${PADDLE_ROOT}/build/fluid_inference_install_dir
cd `dirname $0`
current_dir=`pwd`
if [ $2 == ON ]; then
# You can export yourself if move the install path
MKL_LIB=${PADDLE_ROOT}/build/fluid_install_dir/third_party/install/mklml/lib
MKL_LIB=${inference_install_dir}/third_party/install/mklml/lib
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${MKL_LIB}
fi
if [ $3 == ON ]; then
......@@ -55,7 +56,7 @@ cd build
for WITH_STATIC_LIB in ON OFF; do
# -----simple_on_word2vec-----
rm -rf *
cmake .. -DPADDLE_LIB=${PADDLE_ROOT}/build/fluid_install_dir/ \
cmake .. -DPADDLE_LIB=${inference_install_dir} \
-DWITH_MKL=$TURN_ON_MKL \
-DDEMO_NAME=simple_on_word2vec \
-DWITH_GPU=$TEST_GPU_CPU \
......@@ -75,7 +76,7 @@ for WITH_STATIC_LIB in ON OFF; do
fi
# ---------vis_demo---------
rm -rf *
cmake .. -DPADDLE_LIB=${PADDLE_ROOT}/build/fluid_install_dir/ \
cmake .. -DPADDLE_LIB=${inference_install_dir} \
-DWITH_MKL=$TURN_ON_MKL \
-DDEMO_NAME=vis_demo \
-DWITH_GPU=$TEST_GPU_CPU \
......@@ -98,7 +99,7 @@ for WITH_STATIC_LIB in ON OFF; do
# --------tensorrt mobilenet------
if [ $USE_TENSORRT == ON -a $TEST_GPU_CPU == ON ]; then
rm -rf *
cmake .. -DPADDLE_LIB=${PADDLE_ROOT}/build/fluid_install_dir/ \
cmake .. -DPADDLE_LIB=${inference_install_dir} \
-DWITH_MKL=$TURN_ON_MKL \
-DDEMO_NAME=trt_mobilenet_demo \
-DWITH_GPU=$TEST_GPU_CPU \
......
......@@ -23,7 +23,7 @@ limitations under the License. */
#include <memory>
#include <thread> //NOLINT
#include "paddle/fluid/inference/paddle_inference_api.h"
#include "paddle/include/paddle_inference_api.h"
DEFINE_string(dirname, "", "Directory of the inference model.");
DEFINE_bool(use_gpu, false, "Whether use gpu.");
......@@ -42,8 +42,7 @@ void Main(bool use_gpu) {
config.use_gpu = use_gpu;
config.fraction_of_gpu_memory = 0.15;
config.device = 0;
auto predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
auto predictor = CreatePaddlePredictor<NativeConfig>(config);
for (int batch_id = 0; batch_id < 3; batch_id++) {
//# 2. Prepare input.
......@@ -85,8 +84,7 @@ void MainThreads(int num_threads, bool use_gpu) {
config.use_gpu = use_gpu;
config.fraction_of_gpu_memory = 0.15;
config.device = 0;
auto main_predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
auto main_predictor = CreatePaddlePredictor<NativeConfig>(config);
std::vector<std::thread> threads;
for (int tid = 0; tid < num_threads; ++tid) {
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h> // use glog instead of CHECK to avoid importing other paddle header files.
#include "paddle/fluid/inference/demo_ci/utils.h"
#include "utils.h" // NOLINT
DECLARE_double(fraction_of_gpu_memory_to_use);
DEFINE_string(modeldir, "", "Directory of the inference model.");
......
......@@ -18,7 +18,7 @@
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/inference/paddle_inference_api.h"
#include "paddle/include/paddle_inference_api.h"
namespace paddle {
namespace demo {
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h> // use glog instead of CHECK to avoid importing other paddle header files.
#include "paddle/fluid/inference/demo_ci/utils.h"
#include "utils.h" // NOLINT
#ifdef PADDLE_WITH_CUDA
DECLARE_double(fraction_of_gpu_memory_to_use);
......@@ -34,12 +34,13 @@ DEFINE_bool(use_gpu, false, "Whether use gpu.");
namespace paddle {
namespace demo {
using contrib::AnalysisConfig;
/*
* Use the native fluid engine to inference the demo.
* Use the native and analysis fluid engine to inference the demo.
*/
void Main(bool use_gpu) {
std::unique_ptr<PaddlePredictor> predictor;
NativeConfig config;
std::unique_ptr<PaddlePredictor> predictor, analysis_predictor;
AnalysisConfig config;
config.param_file = FLAGS_modeldir + "/__params__";
config.prog_file = FLAGS_modeldir + "/__model__";
config.use_gpu = use_gpu;
......@@ -49,8 +50,8 @@ void Main(bool use_gpu) {
}
VLOG(3) << "init predictor";
predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
predictor = CreatePaddlePredictor<NativeConfig>(config);
analysis_predictor = CreatePaddlePredictor<AnalysisConfig>(config);
VLOG(3) << "begin to process data";
// Just a single batch of data.
......@@ -68,7 +69,7 @@ void Main(bool use_gpu) {
input.dtype = PaddleDType::FLOAT32;
VLOG(3) << "run executor";
std::vector<PaddleTensor> output;
std::vector<PaddleTensor> output, analysis_output;
predictor->Run({input}, &output, 1);
VLOG(3) << "output.size " << output.size();
......@@ -77,6 +78,10 @@ void Main(bool use_gpu) {
// compare with reference result
CheckOutput(FLAGS_refer, tensor);
// the analysis_output has some diff with native_output,
// TODO(luotao): add CheckOutput for analysis_output later.
analysis_predictor->Run({input}, &analysis_output, 1);
}
} // namespace demo
......
......@@ -308,18 +308,13 @@ TEST(Analyzer_rnn1, ZeroCopy) {
PaddlePlace place;
int output_size{0};
auto predictor =
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
config);
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
config.use_feed_fetch_ops = true;
auto native_predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
auto native_predictor = CreatePaddlePredictor<NativeConfig>(config);
config.use_feed_fetch_ops = true; // the analysis predictor needs feed/fetch.
auto analysis_predictor =
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
config);
auto analysis_predictor = CreatePaddlePredictor<AnalysisConfig>(config);
#define NEW_TENSOR(name__) \
auto name__##_tensor = predictor->GetInputTensor(#name__);
......
......@@ -77,11 +77,9 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
std::unique_ptr<PaddlePredictor> CreateTestPredictor(
const AnalysisConfig &config, bool use_analysis = true) {
if (use_analysis) {
return CreatePaddlePredictor<contrib::AnalysisConfig,
PaddleEngineKind::kAnalysis>(config);
return CreatePaddlePredictor<contrib::AnalysisConfig>(config);
} else {
return CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
config);
return CreatePaddlePredictor<NativeConfig>(config);
}
}
......
......@@ -51,11 +51,8 @@ void CompareTensorRTWithFluid(int batch_size, std::string model_dirname) {
config1.model_dir = model_dirname;
config1.max_batch_size = batch_size;
auto predictor0 =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config0);
auto predictor1 =
CreatePaddlePredictor<MixedRTConfig,
PaddleEngineKind::kAutoMixedTensorRT>(config1);
auto predictor0 = CreatePaddlePredictor<NativeConfig>(config0);
auto predictor1 = CreatePaddlePredictor<MixedRTConfig>(config1);
// Prepare inputs
int height = 224;
int width = 224;
......
......@@ -305,6 +305,7 @@ if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col)
op_library(layer_norm_op DEPS cub)
op_library(reduce_mean_op DEPS cub)
op_library(affine_channel_op DEPS cub)
else()
op_library(conv_op DEPS vol2col im2col)
endif()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
class AffineChannelOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) Feature map input can be a 4D tensor with order NCHW "
"or NHWC. It also can be a 2D tensor and C is the second "
"dimension.");
AddInput("Scale",
"(Tensor) 1D input of shape (C), the c-th element "
"is the scale factor of the affine transformation "
"for the c-th channel of the input.");
AddInput("Bias",
"(Tensor) 1D input of shape (C), the c-th element "
"is the bias of the affine transformation for the "
"c-th channel of the input.");
AddAttr<std::string>(
"data_layout",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddOutput("Out", "(Tensor) A tensor of the same shape and order with X.");
AddComment(R"DOC(
Applies a separate affine transformation to each channel of the input. Useful
for replacing spatial batch norm with its equivalent fixed transformation.
The input also can be 2D tensor and applies a affine transformation in second
dimension.
$$Out = Scale*X + Bias$$
)DOC");
}
};
class AffineChannelOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of AffineChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) of AffineChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of AffineChannelOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of AffineChannelOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", "Out");
}
};
class AffineChannelOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
if (ctx->HasOutput(framework::GradVarName("X"))) {
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"),
ctx->GetInputDim(framework::GradVarName("Out")));
}
if (ctx->HasOutput(framework::GradVarName("Scale"))) {
// Scale@GRAD and Bias@GRAD must exist at the same time.
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
"Output(Scale@GRAD) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
ctx->SetOutputDim(framework::GradVarName("Scale"),
ctx->GetInputDim("Scale"));
ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Scale"));
}
}
};
template <typename T>
using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using ConstEigenArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename DeviceContext, typename T>
class AffineChannelKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X");
auto* scale = ctx.Input<framework::Tensor>("Scale");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* y = ctx.Output<framework::Tensor>("Out");
y->mutable_data<T>(ctx.GetPlace());
const framework::DataLayout layout =
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
auto dims = x->dims();
int N = dims[0];
int C = layout == framework::DataLayout::kNCHW ? dims[1]
: dims[dims.size() - 1];
int HxW = x->numel() / N / C;
auto* scale_d = scale->data<T>();
auto* bias_d = bias->data<T>();
ConstEigenVectorArrayMap<T> a_e(scale_d, C);
ConstEigenVectorArrayMap<T> b_e(bias_d, C);
auto* x_d = x->data<T>();
auto* y_d = y->data<T>();
if (layout == framework::DataLayout::kNCHW) {
int stride = C * HxW;
for (int i = 0; i < N; i++) {
ConstEigenArrayMap<T> x_e(x_d, HxW, C);
EigenArrayMap<T> y_e(y_d, HxW, C);
y_e = (x_e.rowwise() * a_e.transpose()).rowwise() + b_e.transpose();
x_d += stride;
y_d += stride;
}
} else {
int num = N * HxW;
ConstEigenArrayMap<T> x_e(x_d, C, num);
EigenArrayMap<T> y_e(y_d, C, num);
y_e = (x_e.colwise() * a_e).colwise() + b_e;
}
}
};
template <typename DeviceContext, typename T>
class AffineChannelGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X");
auto* scale = ctx.Input<framework::Tensor>("Scale");
auto* dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dscale =
ctx.Output<framework::Tensor>(framework::GradVarName("Scale"));
auto* dbias = ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
const framework::DataLayout layout =
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
auto dims = x->dims();
int N = dims[0];
int C = layout == framework::DataLayout::kNCHW ? dims[1]
: dims[dims.size() - 1];
int HxW = x->numel() / N / C;
auto* x_d = x->data<T>();
auto* dy_d = dy->data<T>();
auto* scale_d = scale->data<T>();
ConstEigenVectorArrayMap<T> scale_e(scale_d, C);
T* dx_d = dx ? dx->mutable_data<T>(ctx.GetPlace()) : nullptr;
T* dscale_d = dscale ? dscale->mutable_data<T>(ctx.GetPlace()) : nullptr;
T* dbias_d = dbias ? dbias->mutable_data<T>(ctx.GetPlace()) : nullptr;
EigenVectorArrayMap<T> dscale_e(dscale_d, C);
EigenVectorArrayMap<T> dbias_e(dbias_d, C);
if (layout == framework::DataLayout::kNCHW) {
// compute dx
int stride = C * HxW;
if (dx) {
for (int i = 0; i < N; i++) {
ConstEigenArrayMap<T> dy_e(dy_d, HxW, C);
EigenArrayMap<T> dx_e(dx_d, HxW, C);
dx_e = dy_e.rowwise() * scale_e.transpose();
dy_d += stride;
dx_d += stride;
}
}
// compute dscale and dbias
if (dscale && dbias) {
dy_d = dy->data<T>();
for (int i = 0; i < N; i++) {
ConstEigenArrayMap<T> x_e(x_d, HxW, C);
ConstEigenArrayMap<T> dy_e(dy_d, HxW, C);
if (i == 0) {
dscale_e = (x_e * dy_e).colwise().sum();
} else {
dscale_e += (x_e * dy_e).colwise().sum();
}
if (i == 0) {
dbias_e = dy_e.colwise().sum();
} else {
dbias_e += dy_e.colwise().sum();
}
x_d += stride;
dy_d += stride;
}
}
} else {
int num = N * HxW;
ConstEigenArrayMap<T> dy_e(dy_d, C, num);
// compute dx
if (dx) {
EigenArrayMap<T> dx_e(dx_d, C, num);
dx_e = dy_e.colwise() * scale_e;
}
// compute dscale and dbias
if (dscale && dbias) {
ConstEigenArrayMap<T> x_e(x_d, C, num);
dscale_e = (x_e * dy_e).rowwise().sum();
dbias_e = dy_e.rowwise().sum();
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(affine_channel, ops::AffineChannelOp,
ops::AffineChannelOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad);
REGISTER_OP_CPU_KERNEL(affine_channel, ops::AffineChannelKernel<CPU, float>,
ops::AffineChannelKernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(affine_channel_grad,
ops::AffineChannelGradKernel<CPU, float>,
ops::AffineChannelGradKernel<CPU, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cub/cub.cuh"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
template <typename T, framework::DataLayout layout, bool HasBias>
__global__ void KeAffineChannelCUDA(const T* x, const T* scale, const T* bias,
const int C, const int HxW, const int num,
T* y) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = gid; i < num; i += stride) {
const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C;
if (HasBias) {
y[i] = scale[c] * x[i] + bias[c];
} else {
y[i] = scale[c] * x[i];
}
}
}
template <typename DeviceContext, typename T>
class AffineChannelCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X");
auto* scale = ctx.Input<framework::Tensor>("Scale");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* y = ctx.Output<framework::Tensor>("Out");
y->mutable_data<T>(ctx.GetPlace());
const framework::DataLayout layout =
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto dims = x->dims();
const int num = x->numel();
int N = dims[0];
int C = layout == framework::DataLayout::kNCHW ? dims[1]
: dims[dims.size() - 1];
int HxW = num / N / C;
const T* x_d = x->data<T>();
const T* scale_d = scale->data<T>();
const T* bias_d = bias->data<T>();
T* y_d = y->data<T>();
int block = 1024;
int grid = (num + block - 1) / block;
if (layout == framework::DataLayout::kNCHW) {
KeAffineChannelCUDA<T, framework::DataLayout::kNCHW,
true><<<grid, block, 0, dev_ctx.stream()>>>(
x_d, scale_d, bias_d, C, HxW, num, y_d);
} else {
KeAffineChannelCUDA<T, framework::DataLayout::kNHWC,
true><<<grid, block, 0, dev_ctx.stream()>>>(
x_d, scale_d, bias_d, C, HxW, num, y_d);
}
}
};
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void AffineChannelScaleBiasGradientCUDAKernel(
const T* dy, const T* x, const int N, const int C, const int HxW, T* dscale,
T* dbias) {
const int outer_size = C;
const int inner_size = N * HxW;
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ds_storage;
__shared__ typename BlockReduce::TempStorage db_storage;
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
T ds_sum = 0;
T db_sum = 0;
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
const int index = layout == framework::DataLayout::kNCHW
? (j / HxW * C + i) * HxW + j % HxW
: j * outer_size + i;
ds_sum += dy[index] * x[index];
db_sum += dy[index];
}
ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
if (threadIdx.x == 0) {
dscale[i] = ds_sum;
dbias[i] = db_sum;
}
__syncthreads();
}
}
template <typename DeviceContext, typename T>
class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X");
auto* scale = ctx.Input<framework::Tensor>("Scale");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dscale =
ctx.Output<framework::Tensor>(framework::GradVarName("Scale"));
auto* dbias = ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
const framework::DataLayout layout =
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto dims = x->dims();
const int num = x->numel();
int N = dims[0];
int C = layout == framework::DataLayout::kNCHW ? dims[1]
: dims[dims.size() - 1];
int HxW = num / N / C;
const T* x_d = x->data<T>();
const T* dy_d = dy->data<T>();
const T* s_d = scale->data<T>();
T* dx_d = dx ? dx->mutable_data<T>(ctx.GetPlace()) : nullptr;
T* ds_d = dscale ? dscale->mutable_data<T>(ctx.GetPlace()) : nullptr;
T* db_d = dbias ? dbias->mutable_data<T>(ctx.GetPlace()) : nullptr;
const int block = 1024;
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid1 = (num + block - 1) / block;
int grid2 = std::min(C, max_blocks);
if (layout == framework::DataLayout::kNCHW) {
if (dx) {
KeAffineChannelCUDA<T, framework::DataLayout::kNCHW,
false><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_d, s_d, nullptr, C, HxW, num, dx_d);
}
if (dscale && dbias) {
AffineChannelScaleBiasGradientCUDAKernel<
T, block, framework::DataLayout::kNCHW><<<grid2, block, 0,
dev_ctx.stream()>>>(
dy_d, x_d, N, C, HxW, ds_d, db_d);
}
} else {
if (dx) {
KeAffineChannelCUDA<T, framework::DataLayout::kNCHW,
false><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_d, s_d, nullptr, C, HxW, num, dx_d);
}
if (dscale && dbias) {
AffineChannelScaleBiasGradientCUDAKernel<
T, block, framework::DataLayout::kNHWC><<<grid2, block, 0,
dev_ctx.stream()>>>(
dy_d, x_d, N, C, HxW, ds_d, db_d);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(affine_channel,
ops::AffineChannelCUDAKernel<CUDA, float>,
ops::AffineChannelCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(affine_channel_grad,
ops::AffineChannelGradCUDAKernel<CUDA, float>,
ops::AffineChannelGradCUDAKernel<CUDA, double>);
......@@ -20,7 +20,7 @@ if(WITH_GRPC)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
return()
endif()
......
......@@ -73,10 +73,11 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope));
const std::string method = "SendRPC";
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer req;
......@@ -87,10 +88,16 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
// stub context
s->response_call_back_ = nullptr;
platform::RecordEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
......@@ -122,10 +129,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope));
const std::string method = "GetRPC";
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, s, this] {
framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
......@@ -137,10 +145,16 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
// stub context
s->response_call_back_ = ProcGetResponse;
platform::RecordEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
......@@ -161,12 +175,14 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(
new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope));
const std::string method = "PrefetchRPC";
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
s, this] {
s, method, h, this] {
auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req;
......@@ -177,11 +193,17 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
// stub context
s->response_call_back_ = ProcGetResponse;
platform::RecordEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
......@@ -193,15 +215,24 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE,
nullptr, nullptr));
const std::string method = "BatchBarrierRPC";
VarHandlePtr h(
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
platform::RecordEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
......@@ -209,15 +240,24 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE,
nullptr, nullptr));
const std::string method = "FetchBarrierRPC";
VarHandlePtr h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
platform::RecordEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
......@@ -226,15 +266,23 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
VarHandlePtr h(
new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr));
const std::string method = "SendCompleteRPC";
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE);
platform::RecordEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
......@@ -244,17 +292,27 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const auto ch = GetChannel(ep);
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE,
nullptr, nullptr));
const std::string method = "CheckPointNotifyRPC";
VarHandlePtr h(
new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);
platform::RecordEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
......@@ -273,6 +331,7 @@ void GRPCClient::Proceed() {
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
GPR_ASSERT(ok);
PADDLE_ENFORCE(c);
if (c->status_.ok()) {
VLOG(3) << c->GetVarHandlePtr()->String() << " process";
c->Process();
......
......@@ -36,6 +36,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_name) {
platform::RecordEvent record_event("serial", &ctx);
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback destroy_callback = [](void* backing) {};
......@@ -147,6 +148,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var) {
platform::RecordEvent record_event("deserial", &ctx);
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar();
......
......@@ -136,9 +136,9 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
// since infershape can not get lod info
PADDLE_ENFORCE_EQ(ref_lod.size(), 1UL, "Only support input lod size is 1.");
PADDLE_ENFORCE_EQ(in1_lod.size(), 1UL, "Only support input lod size is 1.");
PADDLE_ENFORCE_EQ(in1_lod[0].size() - 1, N,
PADDLE_ENFORCE_EQ(static_cast<int>(in1_lod[0].size() - 1), N,
"Batch size of all inputs should be equal.");
PADDLE_ENFORCE_EQ(in1_lod[0][N], N,
PADDLE_ENFORCE_EQ(static_cast<int>(in1_lod[0][N]), N,
"Seq_length of other inputs should be 1.");
PADDLE_ENFORCE_EQ(in1_dims[0], N, "input height should be batch size.");
for (size_t i = 2; i < ins.size(); ++i) {
......
......@@ -66,7 +66,7 @@ static void ParallelExecuteBlocks(
<< "pointer: " << prepared[run_block].get();
executor->RunPreparedContext(prepared[run_block].get(), scope);
} catch (const std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
LOG(FATAL) << "run sub program:" << idx << " error " << e.what();
}
}));
}
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <sys/time.h>
#include <cmath> // for exp
#include <cstring> // for memcpy
#include <random>
#include <string>
#include <vector>
#include "gflags/gflags.h"
......
......@@ -24,7 +24,7 @@ class MomentumOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(param) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
......@@ -45,12 +45,15 @@ class MomentumOp : public framework::OperatorWithKernel {
"Output(VelocityOut) of Momentum should not be null.");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and Grad input of MomentumOp should have the same dimension.");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Velocity"),
"Param and Velocity of MomentumOp should have the same dimension.");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and Grad input of MomentumOp should have the same dimension.");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Velocity"),
"Param and Velocity of MomentumOp should have the same dimension.");
}
PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1,
"Learning_rate should be a scalar");
......@@ -58,13 +61,34 @@ class MomentumOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("VelocityOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
const framework::ExecutionContext& ctx) const override {
auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class MomentumOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto input_var = op_desc.Input("Param")[0];
for (auto& out_var : op_desc.Output("ParamOut")) {
if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::SELECTED_ROWS);
} else if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
framework::proto::VarType::LOD_TENSOR) {
block->FindRecursiveOrCreateVar(out_var).SetType(
framework::proto::VarType::LOD_TENSOR);
} else {
PADDLE_THROW(
"Only support LodTensor and SelectedRows, Unexpected Input Type.");
}
}
}
};
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -115,6 +139,9 @@ $$
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(momentum, ops::MomentumOp, ops::MomentumOpMaker);
REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel<float>,
ops::MomentumOpKernel<double>);
REGISTER_OPERATOR(momentum, ops::MomentumOp, ops::MomentumOpMaker,
paddle::framework::EmptyGradOpMaker,
ops::MomentumOpInferVarType);
REGISTER_OP_CPU_KERNEL(
momentum, ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, double>);
......@@ -15,76 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/momentum_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void MomentumKernel(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu,
const int64_t num, bool use_nesterov, T* p_out,
T* v_out) {
T lr = learning_rate[0];
if (use_nesterov) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
T g_val = g[i];
T v_new = v[i] * mu + g_val;
v_out[i] = v_new;
p_out[i] = p[i] - (g_val + v_new * mu) * lr;
}
} else {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
T v_new = v[i] * mu + g[i];
v_out[i] = v_new;
p_out[i] = p[i] - lr * v_new;
}
}
}
template <typename T>
class MomentumOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
auto param = ctx.Input<framework::Tensor>("Param");
auto velocity = ctx.Input<framework::Tensor>("Velocity");
auto grad = ctx.Input<framework::Tensor>("Grad");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto* p = param->data<T>();
auto* v = velocity->data<T>();
auto* g = grad->data<T>();
auto* lr = learning_rate->data<T>();
int block = 512;
int grid = (param->numel() + block - 1) / block;
MomentumKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(momentum, ops::MomentumOpCUDAKernel<float>,
ops::MomentumOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
momentum, ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -13,35 +13,48 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
auto param = ctx.Input<framework::Tensor>("Param");
auto velocity = ctx.Input<framework::Tensor>("Velocity");
auto grad = ctx.Input<framework::Tensor>("Grad");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
using framework::Tensor;
using framework::SelectedRows;
struct NoNesterov;
struct UseNesterov;
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
template <typename T>
class CPUDenseMomentumFunctor {
private:
const Tensor* param;
const Tensor* grad;
const Tensor* velocity;
const Tensor* learning_rate;
const T mu;
const T use_nesterov;
Tensor* param_out;
Tensor* velocity_out;
T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
public:
CPUDenseMomentumFunctor(const Tensor* param, const Tensor* grad,
const Tensor* velocity, const Tensor* learning_rate,
const T mu, const bool use_nesterov,
Tensor* param_out, Tensor* velocity_out)
: param(param),
grad(grad),
velocity(velocity),
learning_rate(learning_rate),
mu(mu),
use_nesterov(use_nesterov),
param_out(param_out),
velocity_out(velocity_out) {}
inline void operator()() {
auto p_out = framework::EigenVector<T>::Flatten(*param_out);
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
......@@ -59,5 +72,283 @@ class MomentumOpKernel : public framework::OpKernel<T> {
}
};
template <typename T, typename UpdateMethod>
class DenseMomentumFunctor;
// NOTE(dzh) for performance.
// avoid if/else in inside kernel, implement GPU UseNesterov/NoNesterov as two
// functor.
template <typename T>
class DenseMomentumFunctor<T, UseNesterov> {
private:
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const T mu_;
const int64_t num_;
T* p_out_;
T* v_out_;
public:
DenseMomentumFunctor(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu, const int64_t num,
T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
lr_(learning_rate),
mu_(mu),
num_(num),
p_out_(p_out),
v_out_(v_out) {}
inline HOSTDEVICE void operator()(size_t i) const {
// put memory access in register
const T p = p_[i];
const T g = g_[i];
const T lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - (g + v_out * mu_) * lr;
// write reigster to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
}
};
template <typename T>
class DenseMomentumFunctor<T, NoNesterov> {
private:
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const T mu_;
const int64_t num_;
T* p_out_;
T* v_out_;
public:
DenseMomentumFunctor(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu, const int64_t num,
T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
lr_(learning_rate),
mu_(mu),
num_(num),
p_out_(p_out),
v_out_(v_out) {}
inline HOSTDEVICE void operator()(size_t i) const {
// put memory access in register
const T p = p_[i];
const T g = g_[i];
const T lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - lr * v_out;
// write reigster to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
}
};
template <typename T, typename UpdateMethod>
class SparseMomentumFunctor;
template <typename T>
class SparseMomentumFunctor<T, UseNesterov> {
private:
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const T mu_;
const int64_t* rows_;
const int64_t row_numel_;
const int64_t row_height_;
T* p_out_;
T* v_out_;
public:
SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
const T mu, const int64_t* rows, int64_t row_numel,
int64_t row_height, T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
lr_(lr),
mu_(mu),
rows_(rows),
row_numel_(row_numel),
row_height_(row_height),
p_out_(p_out),
v_out_(v_out) {}
inline HOSTDEVICE void operator()(size_t i) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
// put memory access in register
const T p = p_[i];
const T lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - (g + v_out * mu_) * lr;
// write reigster to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
}
};
template <typename T>
class SparseMomentumFunctor<T, NoNesterov> {
private:
const T* p_;
const T* g_;
const T* v_;
const T* lr_;
const T mu_;
const int64_t* rows_;
const int64_t row_numel_;
const int64_t row_height_;
T* p_out_;
T* v_out_;
public:
SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
const T mu, const int64_t* rows, int64_t row_numel,
int64_t row_height, T* p_out, T* v_out)
: p_(p),
g_(g),
v_(v),
lr_(lr),
mu_(mu),
rows_(rows),
row_numel_(row_numel),
row_height_(row_height),
p_out_(p_out),
v_out_(v_out) {}
inline HOSTDEVICE void operator()(size_t i) {
auto row_idx =
math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
// put memory access in register
const T p = p_[i];
const T lr = lr_[0];
const T v = v_[i];
T v_out = v * mu_ + g;
T p_out = p - v_out * lr;
// write reigster to memory
v_out_[i] = v_out;
p_out_[i] = p_out;
}
};
template <typename DeviceContext, typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
T mu = static_cast<T>(ctx.Attr<float>("mu"));
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param = ctx.Input<framework::Tensor>("Param");
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* velocity = ctx.Input<framework::Tensor>("Velocity");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
auto* grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::LoDTensor>()) {
auto grad = ctx.Input<framework::Tensor>("Grad");
if (platform::is_cpu_place(ctx.GetPlace())) {
CPUDenseMomentumFunctor<T> functor(param, grad, velocity, learning_rate,
mu, use_nesterov, param_out,
velocity_out);
functor();
} else if (platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param->numel());
if (use_nesterov) {
DenseMomentumFunctor<T, UseNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<T>(),
learning_rate->data<T>(), mu, param->numel(),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
} else {
DenseMomentumFunctor<T, NoNesterov> functor(
param->data<T>(), grad->data<T>(), velocity->data<T>(),
learning_rate->data<T>(), mu, param->numel(),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
}
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
// sparse update embedding with selectedrows
auto grad = ctx.Input<framework::SelectedRows>("Grad");
// sparse update maybe empty.
if (grad->rows().size() == 0) {
VLOG(3) << "Grad SelectedRows contains no data!";
return;
}
auto* merged_grad = const_cast<framework::Scope&>(ctx.scope())
.Var()
->GetMutable<framework::SelectedRows>();
math::scatter::MergeAdd<DeviceContext, T> merge_func;
merge_func(ctx.template device_context<DeviceContext>(), *grad,
merged_grad);
const int64_t* rows = nullptr;
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(ctx.GetPlace())) {
rows = merged_grad->rows().CUDAData(ctx.GetPlace());
} else {
#endif
rows = merged_grad->rows().data();
#ifdef PADDLE_WITH_CUDA
}
#endif
int64_t row_numel =
merged_grad->value().numel() / merged_grad->rows().size();
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param->numel());
if (use_nesterov) {
SparseMomentumFunctor<T, UseNesterov> functor(
param->data<T>(), merged_grad->value().data<T>(),
velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
static_cast<int64_t>(merged_grad->rows().size()),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
} else {
SparseMomentumFunctor<T, NoNesterov> functor(
param->data<T>(), merged_grad->value().data<T>(),
velocity->data<T>(), learning_rate->data<T>(), mu, rows, row_numel,
static_cast<int64_t>(merged_grad->rows().size()),
param_out->mutable_data<T>(ctx.GetPlace()),
velocity_out->mutable_data<T>(ctx.GetPlace()));
for_range(functor);
}
} else {
PADDLE_THROW(
string::Sprintf("MomentumOp only supports LoDTensor or SelectedRows "
"gradient, but the received Variable Type is %s",
grad_var->Type().name()));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -43,17 +43,31 @@ class SumKernel : public framework::OpKernel<T> {
out->mutable_data<T>(context.GetPlace());
}
auto result = EigenVector<T>::Flatten(*out);
auto &place =
*context.template device_context<DeviceContext>().eigen_device();
int start = in_place ? 1 : 0;
if (!in_place) {
math::SetConstant<DeviceContext, T> constant_functor;
constant_functor(context.template device_context<DeviceContext>(), out,
0.0);
if ((in_num >= 2) && in_vars[0]->IsType<framework::LoDTensor>() &&
in_vars[1]->IsType<framework::LoDTensor>()) {
auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
if (in_0.numel() && in_1.numel()) {
auto in_0_e = EigenVector<T>::Flatten(in_0);
auto in_1_e = EigenVector<T>::Flatten(in_1);
result.device(place) = in_0_e + in_1_e;
start = 2;
}
}
if (start != 2) {
math::SetConstant<DeviceContext, T> constant_functor;
constant_functor(context.template device_context<DeviceContext>(),
out, 0.0);
}
}
math::SelectedRowsAddToTensor<DeviceContext, T> functor;
auto &place =
*context.template device_context<DeviceContext>().eigen_device();
// If in_place, just skip the first tensor
for (size_t i = in_place ? 1 : 0; i < in_num; i++) {
for (size_t i = start; i < in_num; i++) {
if (in_vars[i]->IsType<framework::LoDTensor>()) {
auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
if (in_t.numel() == 0) {
......
......@@ -71,6 +71,7 @@ void PopEvent(const std::string& name, const DeviceContext* dev_ctx);
#if !defined(_WIN32)
struct RecordEvent {
// dev_ctx can be set to nullptr if device is cpu.
RecordEvent(const std::string& name, const DeviceContext* dev_ctx);
~RecordEvent();
......
......@@ -390,7 +390,9 @@ function run_mac_test() {
Running unit tests ...
========================================
EOF
#remove proxy here to fix dist error on mac
export http_proxy=
export https_proxy=
# TODO: jiabin need to refine this part when these tests fixed on mac
ctest --output-on-failure -j $1
# make install should also be test when unittest
......@@ -659,6 +661,7 @@ function gen_fluid_lib() {
EOF
cmake .. -DWITH_DISTRIBUTE=OFF
make -j `nproc` fluid_lib_dist
make -j `nproc` inference_lib_dist
fi
}
......@@ -672,6 +675,8 @@ EOF
cd ${PADDLE_ROOT}/build
cp -r fluid_install_dir fluid
tar -czf fluid.tgz fluid
cp -r fluid_inference_install_dir fluid_inference
tar -czf fluid_inference.tgz fluid_inference
fi
}
......@@ -683,7 +688,9 @@ function test_fluid_lib() {
========================================
EOF
cd ${PADDLE_ROOT}/paddle/fluid/inference/api/demo_ci
./run.sh ${PADDLE_ROOT} ${WITH_MKL:-ON} ${WITH_GPU:-OFF} ${INFERENCE_DEMO_INSTALL_DIR} ${TENSORRT_INCLUDE_DIR:-/usr/local/TensorRT/include} ${TENSORRT_LIB_DIR:-/usr/local/TensorRT/lib}
./run.sh ${PADDLE_ROOT} ${WITH_MKL:-ON} ${WITH_GPU:-OFF} ${INFERENCE_DEMO_INSTALL_DIR} \
${TENSORRT_INCLUDE_DIR:-/usr/local/TensorRT/include} \
${TENSORRT_LIB_DIR:-/usr/local/TensorRT/lib}
./clean.sh
fi
}
......
......@@ -1522,13 +1522,17 @@ class Program(object):
>>> with program.lr_schedule_guard():
>>> lr = lr * decay
"""
tmp_role = self._current_role
tmp_var = self._op_role_var
OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.LRSched
# TODO(typhoonzero): how to set target learning rate var
self._op_role_var = []
yield
self._op_role_var = []
self._current_role = OpRole.Forward
self._op_role_var = tmp_var
self._current_role = tmp_role
def __str__(self):
"""
......
......@@ -153,6 +153,7 @@ __all__ = [
'mul',
'sigmoid_cross_entropy_with_logits',
'maxout',
'affine_channel',
]
......@@ -7268,3 +7269,44 @@ def maxout(x, groups, name=None):
attrs={"groups": groups},
outputs={"Out": out})
return out
def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
"""
Applies a separate affine transformation to each channel of the input.
Useful for replacing spatial batch norm with its equivalent fixed
transformation. The input also can be 2D tensor and applies a affine
transformation in second dimension.
Args:
x (Variable): Feature map input can be a 4D tensor with order NCHW
or NHWC. It also can be a 2D tensor and the affine transformation
is applied in the second dimension.
scale (Variable): 1D input of shape (C), the c-th element is the scale
factor of the affine transformation for the c-th channel of
the input.
bias (Variable): 1D input of shape (C), the c-th element is the bias
of the affine transformation for the c-th channel of the input.
data_layout (string, default NCHW): NCHW or NHWC. If input is 2D
tensor, you can ignore data_layout.
name (str, default None): The name of this layer.
Returns:
out (Variable): A tensor of the same shape and data layout with x.
"""
helper = LayerHelper("affine_channel", **locals())
if name is None:
out = helper.create_tmp_variable(dtype=x.dtype)
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="affine_channel",
inputs={"X": x,
'Scale': scale,
'Bias': bias},
attrs={"data_layout": data_layout},
outputs={"Out": out})
return out
......@@ -15,7 +15,7 @@
from __future__ import print_function
import re
from collections import defaultdict
from paddle.fluid.framework import Program, Variable, name_scope
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
from . import framework
from . import layers
from .backward import append_backward
......@@ -111,7 +111,8 @@ class Optimizer(object):
if param_lr == 1.0:
return self._global_learning_rate()
else:
return self._global_learning_rate() * param_lr
with default_main_program()._lr_schedule_guard():
return self._global_learning_rate() * param_lr
def _create_accumulators(self, block, parameters):
"""Create all accumulators needed by the parameters
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
def affine_channel(x, scale, bias, layout):
C = x.shape[1] if layout == 'NCHW' else x.shape[-1]
if len(x.shape) == 4:
new_shape = (1, C, 1, 1) if layout == 'NCHW' else (1, 1, 1, C)
else:
new_shape = (1, C)
scale = scale.reshape(new_shape)
bias = bias.reshape(new_shape)
return x * scale + bias
class TestAffineChannelOp(OpTest):
def setUp(self):
self.op_type = "affine_channel"
self.init_test_case()
x = np.random.random(self.shape).astype("float32")
scale = np.random.random(self.C).astype("float32")
bias = np.random.random(self.C).astype("float32")
y = affine_channel(x, scale, bias, self.layout)
self.inputs = {'X': x, 'Scale': scale, 'Bias': bias}
self.attrs = {'data_layout': self.layout}
self.outputs = {'Out': y}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X', 'Scale', 'Bias'], 'Out')
def test_check_grad_stopgrad_dx(self):
self.check_grad(['Scale', 'Bias'], 'Out', no_grad_set=set('X'))
def test_check_grad_stopgrad_dscale_dbias(self):
self.check_grad(['X'], 'Out', no_grad_set=set(['Scale', 'Bias']))
def init_test_case(self):
self.shape = [2, 32, 14, 14]
self.C = 32
self.layout = 'NCHW'
class TestAffineChannelNHWC(TestAffineChannelOp):
def init_test_case(self):
self.shape = [2, 14, 14, 32]
self.C = 32
self.layout = 'NHWC'
class TestAffineChannel2D(TestAffineChannelOp):
def init_test_case(self):
self.shape = [16, 64]
self.C = 64
self.layout = 'NCHW'
class TestAffineChannelNCHWLargeShape(TestAffineChannelOp):
def init_test_case(self):
self.shape = [64, 128, 112, 112]
self.C = 128
self.layout = 'NCHW'
# since the gradient check is very slow in large shape, so skip check_grad
def test_check_grad(self):
pass
def test_check_grad_stopgrad_dx(self):
pass
def test_check_grad_stopgrad_dscale_dbias(self):
pass
class TestAffineChannelNCHWLargeShape(TestAffineChannelNCHWLargeShape):
def init_test_case(self):
self.shape = [64, 112, 112, 512]
self.C = 512
self.layout = 'NHWC'
if __name__ == '__main__':
unittest.main()
......@@ -91,6 +91,8 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
need_envs=need_envs)
# FIXME(tangwei): Learningrate variable is not created on pserver.
"""
class TestDistSimnetBow2x2LookupTableSync(TestDistBase):
def _setup_config(self):
self._sync_mode = True
......@@ -105,7 +107,7 @@ class TestDistSimnetBow2x2LookupTableSync(TestDistBase):
self.check_with_place(
"dist_simnet_bow.py",
delta=1e-5,
check_error_log=False,
check_error_log=True,
need_envs=need_envs)
......@@ -143,7 +145,7 @@ class TestDistSimnetBow2x2LookupTableNotContainLRSync(TestDistBase):
delta=1e-5,
check_error_log=False,
need_envs=need_envs)
"""
if __name__ == "__main__":
unittest.main()
......@@ -16,6 +16,8 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from op_test import OpTest
......@@ -88,5 +90,97 @@ class TestMomentumOp2(OpTest):
self.check_output()
class TestSparseMomentumOp(unittest.TestCase):
def setUp(self):
self.use_nesterov = False
def check_with_place(self, place):
self.init_kernel()
scope = core.Scope()
# create and initialize Grad Variable
height = 10
rows = [0, 4, 7]
row_numel = 12
mu = 1.0
use_nesterov = self.use_nesterov
# create and initialize Param Variable
param = scope.var('Param').get_tensor()
param_array = np.full((height, row_numel), 5.0).astype("float32")
param.set(param_array, place)
param_out = scope.var("ParamOut").get_tensor()
param_out_array = np.full((height, row_numel), 0.0).astype("float32")
param_out.set(param_out_array, place)
grad_selected_rows = scope.var('Grad').get_selected_rows()
grad_selected_rows.set_height(height)
grad_selected_rows.set_rows(rows)
grad_np_array = np.ones((len(rows), row_numel)).astype("float32")
grad_np_array[0, 0] = 2.0
grad_np_array[2, 8] = 4.0
grad_tensor = grad_selected_rows.get_tensor()
grad_tensor.set(grad_np_array, place)
velocity = scope.var('Velocity').get_tensor()
velocity_np_array = np.ones((height, row_numel)).astype("float32")
velocity.set(velocity_np_array, place)
velocity_out = scope.var('VelocityOut').get_tensor()
velocity_out_np_array = np.full((height, row_numel),
0.0).astype("float32")
velocity_out.set(velocity_out_np_array, place)
# create and initialize LeraningRate Variable
lr = scope.var('LearningRate').get_tensor()
lr_array = np.full((1), 2.0).astype("float32")
lr.set(lr_array, place)
# create and run operator
op = Operator(
"momentum",
Param='Param',
Grad='Grad',
Velocity='Velocity',
ParamOut='ParamOut',
VelocityOut='VelocityOut',
LearningRate='LearningRate',
mu=mu,
use_nesterov=use_nesterov)
op.run(scope, place)
# get and compare result
param_out_np_array = np.array(param_out)
velocity_out_np_array = np.array(velocity_out)
# TODO(dzh): add a more suitable general numpy interface
# for sparse update.
_grad_np_array = np.full((height, row_numel), 0.0).astype("float32")
for i in range(len(rows)):
_grad_np_array[rows[i]] = grad_np_array[i]
_velocity_out = mu * velocity_np_array + _grad_np_array
_param = param_array
if use_nesterov:
_param_out = _param - (_grad_np_array + _velocity_out * mu
) * lr_array
else:
_param_out = _param - lr_array * _velocity_out
self.assertTrue((_velocity_out == velocity_out_np_array).all())
self.assertTrue((_param_out == param_out_np_array).all())
def init_kernel(self):
pass
def test_sparse_momentum(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
class TestSparseMomentumOp2(TestSparseMomentumOp):
def init_kernel(self):
self.use_nesterov = True
if __name__ == "__main__":
unittest.main()
......@@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['dump_config']
from plot import Ploter
__all__ = ['dump_config', 'Ploter']
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
class PlotData(object):
def __init__(self):
self.step = []
self.value = []
def append(self, step, value):
self.step.append(step)
self.value.append(value)
def reset(self):
self.step = []
self.value = []
class Ploter(object):
"""
Plot input data in a 2D graph
Args:
title: assign the title of input data.
step: x_axis of the data.
value: y_axis of the data.
"""
def __init__(self, *args):
self.__args__ = args
self.__plot_data__ = {}
for title in args:
self.__plot_data__[title] = PlotData()
# demo in notebooks will use Ploter to plot figure, but when we convert
# the ipydb to py file for testing, the import of matplotlib will make the
# script crash. So we can use `export DISABLE_PLOT=True` to disable import
# these libs
self.__disable_plot__ = os.environ.get("DISABLE_PLOT")
if not self.__plot_is_disabled__():
import matplotlib.pyplot as plt
from IPython import display
self.plt = plt
self.display = display
def __plot_is_disabled__(self):
return self.__disable_plot__ == "True"
def append(self, title, step, value):
"""
Feed data
Args:
title: assign the group data to this subtitle.
step: the x_axis of data.
value: the y_axis of data.
Examples:
.. code-block:: python
plot_curve = Ploter("Curve 1","Curve 2")
plot_curve.append(title="Curve 1",step=1,value=1)
"""
assert isinstance(title, basestring)
assert self.__plot_data__.has_key(title)
data = self.__plot_data__[title]
assert isinstance(data, PlotData)
data.append(step, value)
def plot(self, path=None):
"""
Plot data in a 2D graph
Args:
path: store the figure to this file path. Defaul None.
Examples:
.. code-block:: python
plot_curve = Ploter()
plot_cure.plot()
"""
if self.__plot_is_disabled__():
return
titles = []
for title in self.__args__:
data = self.__plot_data__[title]
assert isinstance(data, PlotData)
if len(data.step) > 0:
titles.append(title)
self.plt.plot(data.step, data.value)
self.plt.legend(titles, loc='upper left')
if path is None:
self.display.clear_output(wait=True)
self.display.display(self.plt.gcf())
else:
self.plt.savefig(path)
self.plt.gcf().clear()
def reset(self):
for key in self.__plot_data__:
data = self.__plot_data__[key]
assert isinstance(data, PlotData)
data.reset()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册