提交 686b8935 编写于 作者: P phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_floordiv_and_mod

......@@ -205,7 +205,7 @@ paddle.fluid.layers.maxout (ArgSpec(args=['x', 'groups', 'name'], varargs=None,
paddle.fluid.layers.space_to_depth (ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '5f207ae10589ebe38a63575ef6ff8e1e'))
paddle.fluid.layers.affine_grid (ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '51def402b8910e163cbace9d0c0526ed'))
paddle.fluid.layers.sequence_reverse (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '77a6d80aa5551ca70324fc975c44507f'))
paddle.fluid.layers.affine_channel (ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)), ('document', '2f46f1ff39a13ab00857e7b9f44b2fa7'))
paddle.fluid.layers.affine_channel (ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name', 'act'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None, None)), ('document', 'ab84fdc6dc60f3ad9aa397e6007e3bf9'))
paddle.fluid.layers.similarity_focus (ArgSpec(args=['input', 'axis', 'indexes', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '70e3b5182a18b40b47ecabd7c8490a35'))
paddle.fluid.layers.hash (ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None)), ('document', '9bb77f8dc002dd2ce75d4769eaaf5007'))
paddle.fluid.layers.grid_sampler (ArgSpec(args=['x', 'grid', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd256cba1c41a5ed92ce3f31e24a2ca6d'))
......
cc_library(benchmark SRCS benchmark.cc DEPS enforce)
cc_test(test_benchmark SRCS benchmark_tester.cc DEPS benchmark)
cc_binary(visualizer SRCS visualizer.cc DEPS analysis
paddle_pass_builder ir_pass_manager pass graph_viz_pass analysis_passes)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/utils/visualizer.h"
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <fstream>
#include <memory>
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
#include "paddle/fluid/platform/init.h"
DEFINE_string(model_dir, "", "model directory");
DEFINE_string(model_program_path, "", "model program path");
DEFINE_string(model_params_path, "", "model params path");
using paddle::inference::analysis::Argument;
namespace paddle {
namespace inference {
namespace utils {
void Visualizer::SetArgument(Argument *argument) { argument_ = argument; }
bool Visualizer::Run() {
paddle::framework::InitDevices(false);
paddle::inference::analysis::Analyzer().Run(argument_);
return true;
}
} // namespace utils
} // namespace inference
} // namespace paddle
// Generate a dot file describing the structure of graph.
// To use this tool, run command: ./visualizer [options...]
// Options:
// --model_dir: the directory of model
// --model_program_path: the path of program
// --model_params_path: the path of params
int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
paddle::inference::analysis::Argument argument;
argument.SetUseGPU(false);
argument.SetUseTensorRT(false);
if (FLAGS_model_dir.empty()) {
if (FLAGS_model_program_path.empty() || FLAGS_model_params_path.empty()) {
LOG(ERROR) << "Please set model_dir"
" or model_program_path and model_params_path";
return -1;
} else {
argument.SetModelProgramPath(FLAGS_model_program_path);
argument.SetModelParamsPath(FLAGS_model_params_path);
}
} else {
argument.SetModelDir(FLAGS_model_dir);
}
// Only 1 pass, default filename is 0_ir_origin.dot
// For more details, looking for paddle::inference::analysis::IRPassManager
argument.SetIrAnalysisPasses({"infer_clean_graph_pass", "graph_viz_pass"});
std::unique_ptr<paddle::framework::Scope> scope{
new paddle::framework::Scope()};
argument.SetScopeNotOwned(
const_cast<paddle::framework::Scope *>(scope.get()));
paddle::inference::utils::Visualizer visualizer;
visualizer.SetArgument(&argument);
visualizer.Run();
return 0;
}
USE_PASS(infer_clean_graph_pass);
USE_PASS(graph_viz_pass);
USE_PASS(graph_to_program_pass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/inference/analysis/argument.h"
namespace paddle {
namespace inference {
namespace utils {
using paddle::inference::analysis::Argument;
class Visualizer final {
public:
Visualizer() = default;
~Visualizer() = default;
Visualizer(const Visualizer &) = delete;
Visualizer &operator=(const Visualizer &) = delete;
void SetArgument(Argument *);
bool Run();
private:
Argument *argument_;
};
} // namespace utils
} // namespace inference
} // namespace paddle
......@@ -113,6 +113,27 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel {
}
};
class AffineChannelGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("affine_channel_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput("Scale", Input("Scale"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetOutput(framework::GradVarName("Scale"), InputGrad("Scale"));
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
template <typename T>
using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
......@@ -260,8 +281,7 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(affine_channel, ops::AffineChannelOp,
ops::AffineChannelOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::AffineChannelOpMaker, ops::AffineChannelGradMaker);
REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad);
REGISTER_OP_CPU_KERNEL(affine_channel, ops::AffineChannelKernel<CPU, float>,
......
......@@ -75,15 +75,11 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
// TODO(dengkaipeng): use PERSISTENT mode in training may incur errors
// in inference period, cuDNN fixed issues on PERSISTENT mode in version
// 7.0.2, 7.0.4 and 7.3.0, we disable this mode currently.
// #if CUDNN_VERSION_MIN(7, 0, 0)
// mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
// #else
#if CUDNN_VERSION_MIN(7, 0, 0)
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
// #endif
#endif
VLOG(3) << "Setting descriptors.";
std::vector<int> dims;
......@@ -305,15 +301,11 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
<< "CUDNN_BN_MIN_EPSILON instead.";
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
// TODO(dengkaipeng): use PERSISTENT mode in training may incur errors
// in inference period, cuDNN fixed issues on PERSISTENT mode in version
// 7.0.2, 7.0.4 and 7.3.0, we disable this mode currently.
// #if CUDNN_VERSION_MIN(7, 0, 0)
// mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
// #else
#if CUDNN_VERSION_MIN(7, 0, 0)
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#else
mode_ = CUDNN_BATCHNORM_SPATIAL;
// #endif
#endif
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
......
......@@ -32,7 +32,10 @@ class LoDResetOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GT(level0.size(), 1,
"If Input(Y) not provided, the target lod should be "
"specified by attribute `target_lod`.");
} else {
ctx->ShareLoD("Y", "Out");
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
......
......@@ -78,12 +78,6 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"The numel of 'pad_value' can only be 1 or be equal to the "
"'step_width'.");
if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
pad_tensor->Resize(pad_tensor_dims);
return;
}
const int kBlockSize = 512;
/* At least use 32 threads to copy sequence_width elements,
......@@ -129,12 +123,13 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
step_width, layout);
/*
if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
seq_tensor->Resize(seq_tensor_dims);
return;
}
*/
const int kBlockSize = 512;
......
......@@ -290,8 +290,10 @@ class MatMulOp : public framework::OperatorWithKernel {
context->Attrs().Get<bool>("transpose_Y"));
PADDLE_ENFORCE_EQ(mat_dim_x.width_, mat_dim_y.height_);
PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
if (context->IsRuntime()) {
PADDLE_ENFORCE(mat_dim_x.batch_size_ == mat_dim_y.batch_size_ ||
mat_dim_x.batch_size_ == 0 || mat_dim_y.batch_size_ == 0);
}
std::vector<int64_t> dim_out;
if (mat_dim_x.batch_size_ != 0) {
dim_out = framework::vectorize(dim_x);
......
......@@ -453,6 +453,7 @@ function assert_api_spec_approvals() {
echo "checking ${API_FILE} change, PR: ${GIT_PR_ID}, changes: ${API_CHANGE}"
if [ ${API_CHANGE} ] && [ "${GIT_PR_ID}" != "" ]; then
# NOTE: per_page=10000 should be ok for all cases, a PR review > 10000 is not human readable.
# approval_user_list: velconia 1979255,panyx0718 2887803,XiaoguangHu01 46782768,chengduoZH 30176695,Xreki 12538138,luotao1 6836917,sneaxiy 32832641,tensor-tang 21351065,jacquesqiao 3048612,typhoonzero 13348433,shanyi15 35982308.
if [ "$API_FILE" == "paddle/fluid/API.spec" ];then
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 2887803 35982308 46782768 30176695`
......@@ -462,14 +463,14 @@ function assert_api_spec_approvals() {
fi
else
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803`
python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803 1979255 21351065 3048612 13348433 46782768 30176695 12538138 6836917 32832641`
fi
echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}"
if [ "${APPROVALS}" == "FALSE" ]; then
if [ "$API_FILE" == "paddle/fluid/API.spec" ];then
echo "You must have one RD (panyx0718 or chengduoZH or XiaoguangHu01) and one PM (shanyi15) approval for the api change! ${API_FILE}"
else
echo "You must have panyx0718 approval for the api change! ${API_FILE}"
echo "You must have one RD (velconia,panyx0718,XiaoguangHu01,chengduoZH,Xreki,luotao1,sneaxiy,tensor-tang,jacquesqiao,typhoonzero) approval for the api change! ${API_FILE}"
fi
exit 1
fi
......@@ -479,10 +480,10 @@ function assert_api_spec_approvals() {
HAS_CONST_CAST=`git diff -U0 upstream/$BRANCH |grep -o -m 1 "const_cast" || true`
if [ ${HAS_CONST_CAST} ] && [ "${GIT_PR_ID}" != "" ]; then
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803`
python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803 1979255 21351065 3048612 13348433 46782768 30176695 12538138 6836917 32832641`
echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}"
if [ "${APPROVALS}" == "FALSE" ]; then
echo "You must have panyx0718 approval for the const_cast"
echo "You must have one RD (velconia,panyx0718,XiaoguangHu01,chengduoZH,Xreki,luotao1,sneaxiy,tensor-tang,jacquesqiao,typhoonzero) approval for the api change! ${API_FILE}"
exit 1
fi
fi
......
......@@ -4901,6 +4901,9 @@ def matmul(x, y, transpose_x=False, transpose_y=False, alpha=1.0, name=None):
if len(y_shape) > 2 and len(x_shape) > 2:
for i, dim_x in enumerate(x_shape[:-2]):
# don't check neg shape
if dim_x < 0 or y_shape[i] < 0:
continue
if dim_x != y_shape[i]:
raise ValueError("Invalid inputs for matmul. x(%s), y(%s)" %
(x.shape, y.shape))
......@@ -9721,7 +9724,12 @@ def sequence_reverse(x, name=None):
return out
def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
def affine_channel(x,
scale=None,
bias=None,
data_layout='NCHW',
name=None,
act=None):
"""
Applies a separate affine transformation to each channel of the input.
Useful for replacing spatial batch norm with its equivalent fixed
......@@ -9740,6 +9748,7 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
data_layout (string, default NCHW): NCHW or NHWC. If input is 2D
tensor, you can ignore data_layout.
name (str, default None): The name of this layer.
act (str, default None): Activation to be applied to the output of this layer.
Returns:
out (Variable): A tensor of the same shape and data layout with x.
......@@ -9759,7 +9768,7 @@ def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None):
'Bias': bias},
attrs={"data_layout": data_layout},
outputs={"Out": out})
return out
return helper.append_activation(pre_activation)
def similarity_focus(input, axis, indexes, name=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册