提交 5670e9ea 编写于 作者: P peizhilin

Merge remote-tracking branch 'upstream/develop' into windows/build

......@@ -57,6 +57,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
......
......@@ -29,6 +29,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
if (type == "mul") {
op->SetInput("X", {inputs[0]});
op->SetInput("Y", {inputs[1]});
op->SetAttr("x_num_col_dims", {1});
} else if (type == "elementwise_add") {
op->SetInput("X", inputs);
}
......
......@@ -412,7 +412,7 @@ void DetachDeletedNodes(framework::ir::Graph *graph) {
void SubGraphFuser::ReplaceNodesWithSubGraphs() {
auto subgraphs = SubgraphDetector(graph_, node_inside_subgraph_teller_)();
for (auto &subgraph : subgraphs) {
if (subgraph.size() <= min_subgraph_size_) continue;
if (subgraph.size() <= (size_t)min_subgraph_size_) continue;
LOG(INFO) << "detect a subgraph size " << subgraph.size();
std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
// replace this sub-graph with the first node. Two steps: 1. Create a Block
......
......@@ -114,7 +114,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
// it is either an OP's input or an OP's output.
auto &subgraph_nodes = *Agent(node).subgraph();
for (int index = 0; index < block_desc.OpSize(); index++) {
for (size_t index = 0; index < block_desc.OpSize(); index++) {
framework::proto::OpDesc *op = block_desc.Op(index)->Proto();
auto correspond_node = subgraph_nodes[index];
PADDLE_ENFORCE_EQ(correspond_node->Name(), op->type());
......
......@@ -45,7 +45,7 @@ void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
std::unordered_set<std::string> teller_set(
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
"elementwise_add", "dropout"});
"elementwise_add", "dropout", "split"});
if (!node->IsOp()) return false;
if (teller_set.count(node->Op()->Type())) {
......
......@@ -548,4 +548,5 @@ USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout);
USE_TRT_CONVERTER(pad);
USE_TRT_CONVERTER(split);
#endif
......@@ -15,7 +15,7 @@
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <thread>
#include <thread> // NOLINT
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
......
......@@ -23,7 +23,7 @@ limitations under the License. */
#include <memory>
#include <thread> //NOLINT
#include "utils.h"
#include "utils.h" // NOLINT
DEFINE_string(dirname, "", "Directory of the inference model.");
DEFINE_bool(use_gpu, false, "Whether use gpu.");
......
......@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......
......@@ -49,6 +49,8 @@ struct AnalysisConfig : public NativeConfig {
void EnableTensorRtEngine(int workspace_size = 1 << 20,
int max_batch_size = 1);
bool use_tensorrt() const { return use_tensorrt_; }
// NOTE: this is only for internal development; please do not use it.
// It is NOT stable yet.
void EnableMKLDNN();
......
......@@ -91,7 +91,7 @@ class CpuPassStrategy : public PassStrategy {
virtual ~CpuPassStrategy() = default;
virtual void EnableMKLDNN() override {
void EnableMKLDNN() override {
// TODO(Superjomn) Consider the way to mix CPU with GPU.
#ifdef PADDLE_WITH_MKLDNN
passes_.insert(passes_.begin(), "mkldnn_placement_pass");
......@@ -123,7 +123,7 @@ class GpuPassStrategy : public PassStrategy {
GpuPassStrategy(const GpuPassStrategy &other)
: PassStrategy(other.AllPasses()) {}
virtual void EnableMKLDNN() override;
void EnableMKLDNN() override;
virtual ~GpuPassStrategy() = default;
};
......
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto device_context)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
add_subdirectory(plugin)
add_subdirectory(convert)
# Add TRT tests
nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc
DEPS tensorrt_engine operator scope framework_proto op_registry)
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
${FLUID_CORE_MODULES} tensorrt_engine tensorrt_converter)
......@@ -28,6 +29,8 @@ nv_test(test_trt_concat_op SRCS test_concat_op.cc concat_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine concat_op SERIAL)
nv_test(test_trt_dropout_op SRCS test_dropout_op.cc dropout_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine dropout_op SERIAL)
nv_test(test_trt_pad_op SRCS test_pad_op.cc pad_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine pad_op SERIAL)
nv_test(test_trt_split_op SRCS test_split_op.cc split_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine tensorrt_plugin
split_op concat_op SERIAL)
......@@ -19,7 +19,7 @@ namespace inference {
namespace tensorrt {
/*
* MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights.
* ConcatOp
*/
class ConcatOpConverter : public OpConverter {
public:
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* SplitOp.
*/
// Converts a fluid `split` operator into a TensorRT plugin layer backed by
// SplitPlugin.
class SplitOpConverter : public OpConverter {
 public:
  // op:        proto description of the split operator (one "X" input, one or
  //            more "Out" outputs, attributes "axis" and "sections").
  // scope:     not read by this converter.
  // test_mode: when true, every produced tensor is also declared as an engine
  //            output.
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(40) << "convert a fluid split op to tensorrt split layer";

    framework::OpDesc op_desc(op, nullptr);

    // Resolve the input tensor and count the operator's inputs / outputs.
    auto* x_tensor = engine_->GetITensor(op_desc.Input("X")[0]);
    const auto x_dims = x_tensor->getDimensions();
    const int num_inputs = op_desc.Input("X").size();
    const auto& out_names = op_desc.Output("Out");
    const size_t num_outputs = out_names.size();
    PADDLE_ENFORCE(num_inputs == 1);

    // Read the attributes.
    int axis = boost::get<int>(op_desc.GetAttr("axis"));
    std::vector<int> section_sizes =
        boost::get<std::vector<int>>(op_desc.GetAttr("sections"));

    // Axis 0 is rejected. A negative axis is wrapped using the tensor's own
    // rank, a positive one is shifted down by one -- presumably because the
    // TensorRT tensor does not carry the batch dimension (TODO confirm).
    PADDLE_ENFORCE(axis != 0);
    axis = (axis < 0) ? axis + x_dims.nbDims : axis - 1;

    PADDLE_ENFORCE(section_sizes.size() == num_outputs);

    // The engine stores the plugin pointer in its owned-plugin list.
    SplitPlugin* split_plugin = new SplitPlugin(axis, section_sizes);
    nvinfer1::IPluginLayer* split_layer =
        engine_->AddPlugin(&x_tensor, num_inputs, split_plugin);

    // Name and register every output; the layer name lists all of them.
    std::string layer_name = "split (Output: ";
    for (size_t idx = 0; idx < num_outputs; ++idx) {
      const std::string& out_name = out_names[idx];
      auto* out_tensor = split_layer->getOutput(idx);
      out_tensor->setName(out_name.c_str());
      engine_->SetITensor(out_name, split_layer->getOutput(idx));
      layer_name += out_name;
      if (test_mode) {
        engine_->DeclareOutput(out_name);
      }
    }
    layer_name += ")";
    split_layer->setName(layer_name.c_str());
  }
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(split, SplitOpConverter);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
// Converts a split op that cuts a (3, 2, 2) input into a (2, 2, 2) and a
// (1, 2, 2) output along axis 1, then runs the validator with batch size 1.
TEST(split_op, test) {
  std::unordered_set<std::string> parameters({""});
  framework::Scope scope;
  TRTConvertValidation validator(10, parameters, scope, 1000);

  // Declare the variables the operator reads and writes.
  validator.DeclInputVar("split_input", nvinfer1::DimsCHW(3, 2, 2));
  validator.DeclOutputVar("split_out1", nvinfer1::DimsCHW(2, 2, 2));
  validator.DeclOutputVar("split_out2", nvinfer1::DimsCHW(1, 2, 2));

  // Describe the operator: wiring first, then attributes.
  framework::OpDesc split_desc;
  split_desc.SetType("split");
  split_desc.SetInput("X", {"split_input"});
  split_desc.SetOutput("Out", {"split_out1", "split_out2"});

  int split_num = 0;
  int split_axis = 1;
  std::vector<int> sections = {2, 1};
  split_desc.SetAttr("axis", split_axis);
  split_desc.SetAttr("num", split_num);
  split_desc.SetAttr("sections", sections);

  validator.SetOp(*split_desc.Proto());

  validator.Execute(1);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(split);
......@@ -255,6 +255,12 @@ void TensorRTEngine::freshDeviceId() {
cudaSetDevice(device_);
}
// Adds `plugin` to the network as a plugin layer fed by `inputs` and returns
// the created layer.
//
// inputs:   array of `nbInputs` input tensors.
// nbInputs: number of entries in `inputs`.
// plugin:   heap-allocated plugin; the engine takes ownership and keeps it in
//           owned_plugin_.
nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
    nvinfer1::ITensor *const *inputs, int nbInputs, PluginTensorRT *plugin) {
  // Wrap the raw pointer before touching the vector: if growing the vector
  // throws, the temporary unique_ptr still releases the plugin instead of
  // leaking it.
  owned_plugin_.push_back(std::unique_ptr<PluginTensorRT>(plugin));
  return infer_network_.get()->addPluginExt(inputs, nbInputs, *plugin);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
......@@ -125,6 +126,8 @@ class TensorRTEngine : public EngineBase {
void SetRuntimeBatch(size_t batch_size);
int GetRuntimeBatch();
int GetDevice() { return device_; }
nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
int nbInputs, PluginTensorRT*);
// A pointer to CPU memory is needed of the TRT weight.
// Before TRT runs, fluid loads weight into GPU storage.
......@@ -164,8 +167,10 @@ class TensorRTEngine : public EngineBase {
std::unordered_map<std::string /*name*/, size_t /*max size*/> buffer_sizes_;
std::unordered_map<std::string /*name*/, nvinfer1::ITensor* /*ITensor*/>
itensor_map_;
// The specific GPU id that the TensorRTEngine bounded to.
int device_;
std::vector<std::unique_ptr<PluginTensorRT>> owned_plugin_;
// TensorRT related internal members
template <typename T>
......
nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu DEPS enforce)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <cstring>
#include <type_traits>
#include <vector>
// Generic entry points, declared up front because the std::vector serializer
// below stores its element count through them.
template <typename T>
inline void SerializeValue(void** buffer, T const& value);

template <typename T>
inline void DeserializeValue(void const** buffer, size_t* buffer_size,
                             T* value);

// NOTE(review): an unnamed namespace in a header gives every translation unit
// its own copy of these types; kept as-is so existing users of `Serializer`
// keep compiling.
namespace {

// Primary template: intentionally empty, so unsupported types fail to compile
// when one of the static members is requested.
template <typename T, class Enable = void>
struct Serializer {};

// Arithmetic, enum and POD values are stored as their raw bytes.
template <typename T>
struct Serializer<T, typename std::enable_if<std::is_arithmetic<T>::value ||
                                             std::is_enum<T>::value ||
                                             std::is_pod<T>::value>::type> {
  // Number of bytes Serialize() writes for `value`.
  static size_t SerializedSize(T const& value) { return sizeof(T); }

  // Copies `value` to *buffer and advances *buffer past it.
  static void Serialize(void** buffer, T const& value) {
    std::memcpy(*buffer, &value, sizeof(T));
    *buffer = static_cast<char*>(*buffer) + sizeof(T);
  }

  // Reads one T from *buffer, advances *buffer and shrinks *buffer_size.
  static void Deserialize(void const** buffer, size_t* buffer_size, T* value) {
    assert(*buffer_size >= sizeof(T));
    std::memcpy(value, *buffer, sizeof(T));
    *buffer = static_cast<char const*>(*buffer) + sizeof(T);
    *buffer_size -= sizeof(T);
  }
};

// C strings are stored including their terminating NUL.
template <>
struct Serializer<const char*> {
  static size_t SerializedSize(const char* value) {
    return std::strlen(value) + 1;
  }

  static void Serialize(void** buffer, const char* value) {
    const size_t nbyte = std::strlen(value) + 1;
    std::memcpy(*buffer, value, nbyte);
    *buffer = static_cast<char*>(*buffer) + nbyte;
  }

  // *value is set to point INTO the buffer (no copy is made), so it is only
  // valid while the buffer is.
  static void Deserialize(void const** buffer, size_t* buffer_size,
                          const char** value) {
    *value = static_cast<char const*>(*buffer);
    // strnlen never reads past the remaining bytes; a missing terminator makes
    // data_size exceed *buffer_size and trips the assert.
    size_t data_size = strnlen(*value, *buffer_size) + 1;
    assert(*buffer_size >= data_size);
    *buffer = static_cast<char const*>(*buffer) + data_size;
    *buffer_size -= data_size;
  }
};

// Vectors of arithmetic / enum / POD elements are stored as a size_t element
// count followed by the raw element bytes.
template <typename T>
struct Serializer<std::vector<T>,
                  typename std::enable_if<std::is_arithmetic<T>::value ||
                                          std::is_enum<T>::value ||
                                          std::is_pod<T>::value>::type> {
  static size_t SerializedSize(std::vector<T> const& value) {
    return sizeof(value.size()) + value.size() * sizeof(T);
  }

  static void Serialize(void** buffer, std::vector<T> const& value) {
    SerializeValue(buffer, value.size());
    const size_t nbyte = value.size() * sizeof(T);
    // An empty vector may report a null data(); memcpy must not be handed a
    // null pointer even for a zero-byte copy.
    if (nbyte != 0) {
      std::memcpy(*buffer, value.data(), nbyte);
    }
    *buffer = static_cast<char*>(*buffer) + nbyte;
  }

  static void Deserialize(void const** buffer, size_t* buffer_size,
                          std::vector<T>* value) {
    size_t size = 0;
    DeserializeValue(buffer, buffer_size, &size);
    // Validate the stored count against the remaining bytes BEFORE resizing:
    // a corrupt count must not trigger a huge allocation, and the division
    // form cannot overflow the way size * sizeof(T) can.
    assert(size <= *buffer_size / sizeof(T));
    value->resize(size);
    const size_t nbyte = size * sizeof(T);
    if (nbyte != 0) {
      std::memcpy(value->data(), *buffer, nbyte);
    }
    *buffer = static_cast<char const*>(*buffer) + nbyte;
    *buffer_size -= nbyte;
  }
};

}  // namespace

// Returns the number of bytes SerializeValue() writes for `value`.
template <typename T>
inline size_t SerializedSize(T const& value) {
  return Serializer<T>::SerializedSize(value);
}

// Writes `value` at *buffer and advances *buffer past the written bytes.
template <typename T>
inline void SerializeValue(void** buffer, T const& value) {
  return Serializer<T>::Serialize(buffer, value);
}

// Reads a T from *buffer into *value, advances *buffer and reduces
// *buffer_size by the number of bytes consumed.
template <typename T>
inline void DeserializeValue(void const** buffer, size_t* buffer_size,
                             T* value) {
  return Serializer<T>::Deserialize(buffer, buffer_size, value);
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
// Returns the dimensions of output `index`: the dimensions of the single
// input with the extent along axis_ replaced by that output's section length.
nvinfer1::Dims SplitPlugin::getOutputDimensions(int index,
                                                const nvinfer1::Dims* inputDims,
                                                int nbInputs) {
  assert(nbInputs == 1);
  assert(index < this->getNbOutputs());
  nvinfer1::Dims result = inputDims[0];
  // at() range-checks the index even when asserts are compiled out.
  result.d[axis_] = output_length_.at(index);
  return result;
}
int SplitPlugin::initialize() {
std::vector<int> segment_offsets(1, 0);
for (int i = 0; i < this->getNbOutputs(); ++i) {
segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
}
segment_offsets_ = segment_offsets;
nvinfer1::Dims dims = this->getInputDims(0);
nx_ = 1;
for (int i = dims.nbDims - 1; i > axis_; --i) {
nx_ *= dims.d[i];
}
ny_ = dims.d[axis_];
nz_ = 1;
for (int i = axis_ - 1; i >= 0; --i) {
nz_ *= dims.d[i];
}
return 0;
}
int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
void** outputs, void* workspace, cudaStream_t stream) {
auto const& input_dims = this->getInputDims(0);
int input_size = 0;
float const* idata = reinterpret_cast<float const*>(inputs[0]);
float** odatas = reinterpret_cast<float**>(outputs);
// kernel impl here.
int inputBatchOffset = nx_ * ny_ * nz_;
for (size_t i = 0; i < this->getNbOutputs(); i++) {
for (size_t j = 0; j < batchSize; j++) {
cudaMemcpyAsync(
odatas[i] +
j * (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ *
sizeof(float),
inputs[0] +
(inputBatchOffset * j + segment_offsets_[i] * nx_) *
sizeof(float),
(segment_offsets_[i + 1] - segment_offsets_[i]) * nx_ * sizeof(float),
cudaMemcpyDeviceToDevice, stream);
}
}
return cudaGetLastError() != cudaSuccess;
}
} // tensorrt
} // inference
} // paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class SplitPlugin : public PluginTensorRT {
int axis_;
std::vector<int> output_length_;
int nx_, ny_, nz_;
std::vector<int> segment_offsets_;
protected:
virtual size_t getSerializationSize() override {
return SerializedSize(axis_) + SerializedSize(output_length_) +
getBaseSerializationSize();
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
virtual void serialize(void *buffer) override {
serializeBase(buffer);
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, output_length_);
}
public:
SplitPlugin(int axis, std::vector<int> const &output_lengths)
: axis_(axis), output_length_(output_lengths) {
assert(axis <= nvinfer1::Dims::MAX_DIMS);
}
// It was used for tensorrt deserialization.
// It should not be called by users.
SplitPlugin(void const *serialData, size_t serialLength) {
deserializeBase(serialData, serialLength);
DeserializeValue(&serialData, &serialLength, &axis_);
DeserializeValue(&serialData, &serialLength, &output_length_);
}
SplitPlugin *clone() const override {
return new SplitPlugin(axis_, output_length_);
}
virtual const char *getPluginType() const override { return "split"; }
virtual int getNbOutputs() const override { return output_length_.size(); }
virtual nvinfer1::Dims getOutputDimensions(int index,
const nvinfer1::Dims *inputs,
int nbInputDims) override;
virtual int initialize() override;
virtual int enqueue(int batchSize, const void *const *inputs, void **outputs,
void *workspace, cudaStream_t stream) override;
};
} // tensorrt
} // inference
} // paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
// Writes the state shared by all plugins into `buffer`, in this fixed order:
// input_dims_, max_batch_size_, data_type_, data_format_.
// deserializeBase() reads the fields back in the same order, so the two
// functions have to be kept in sync.
// `buffer` is a reference and is passed to SerializeValue by address, so on
// return it points just past the bytes written here.
void PluginTensorRT::serializeBase(void*& buffer) {
SerializeValue(&buffer, input_dims_);
SerializeValue(&buffer, max_batch_size_);
SerializeValue(&buffer, data_type_);
SerializeValue(&buffer, data_format_);
}
// Restores the state written by serializeBase(): input_dims_,
// max_batch_size_, data_type_ and data_format_, read in that order.
// Both arguments are references and are passed to DeserializeValue by
// address, so on return `serialData` points past the consumed bytes and
// `serialLength` holds the number of bytes still unread.
void PluginTensorRT::deserializeBase(void const*& serialData,
size_t& serialLength) {
DeserializeValue(&serialData, &serialLength, &input_dims_);
DeserializeValue(&serialData, &serialLength, &max_batch_size_);
DeserializeValue(&serialData, &serialLength, &data_type_);
DeserializeValue(&serialData, &serialLength, &data_format_);
}
// Returns the number of bytes serializeBase() writes: the serialized sizes of
// input_dims_, max_batch_size_, data_type_ and data_format_ added together.
size_t PluginTensorRT::getBaseSerializationSize() {
  size_t total = SerializedSize(input_dims_);
  total += SerializedSize(max_batch_size_);
  total += SerializedSize(data_type_);
  total += SerializedSize(data_format_);
  return total;
}
// Default format check: accepts only kFLOAT data in kNCHW layout.
bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
                                    nvinfer1::PluginFormat format) const {
  if (type != nvinfer1::DataType::kFLOAT) {
    return false;
  }
  return format == nvinfer1::PluginFormat::kNCHW;
}
// Records the configuration handed to the plugin: data type, data format, a
// copy of the `nbInputs` input dimensions and the maximum batch size.
// `outputDims` and `nbOutputs` are not used.
void PluginTensorRT::configureWithFormat(const nvinfer1::Dims* inputDims,
                                         int nbInputs,
                                         const nvinfer1::Dims* outputDims,
                                         int nbOutputs, nvinfer1::DataType type,
                                         nvinfer1::PluginFormat format,
                                         int maxBatchSize) {
  data_type_ = type;
  data_format_ = format;

  const nvinfer1::Dims* const dims_begin = inputDims;
  const nvinfer1::Dims* const dims_end = inputDims + nbInputs;
  input_dims_.assign(dims_begin, dims_end);

  // max_batch_size_ is a size_t; the conversion from int is spelled out.
  max_batch_size_ = static_cast<size_t>(maxBatchSize);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <cstring>
#include <iostream>
#include <unordered_map>
#include <vector>
#include "NvInfer.h"
#include "paddle/fluid/inference/tensorrt/plugin/serialize.h"
namespace paddle {
namespace inference {
namespace tensorrt {
// Common base class for the TensorRT plugins in this directory. It stores the
// configuration received through configureWithFormat() (input dimensions,
// maximum batch size, data type and data format) and provides helpers to
// serialize / deserialize that shared state.
class PluginTensorRT : public nvinfer1::IPluginExt {
public:
PluginTensorRT() {}
// Both arguments are ignored here; subclasses restore the shared state by
// calling deserializeBase() themselves.
PluginTensorRT(const void* serialized_data, size_t length) {}
// Returns the recorded dimensions of input `index` (range-checked by at()).
nvinfer1::Dims const& getInputDims(int index) const {
return input_dims_.at(index);
}
// Accessors for the values recorded by configureWithFormat() or restored by
// deserializeBase().
size_t getMaxBatchSize() const { return max_batch_size_; }
nvinfer1::DataType getDataType() const { return data_type_; }
nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
virtual const char* getPluginVersion() const { return "1"; }
// By default a plugin requests no workspace memory.
size_t getWorkspaceSize(int) const override { return 0; }
// By default there is nothing to release.
void terminate() override {}
virtual ~PluginTensorRT() {}
// Check format support. The default is FLOAT32 and NCHW.
bool supportsFormat(nvinfer1::DataType type,
nvinfer1::PluginFormat format) const override;
// Stores type, format, the input dimensions and the maximum batch size in the
// members below.
void configureWithFormat(const nvinfer1::Dims* inputDims, int nbInputs,
const nvinfer1::Dims* outputDims, int nbOutputs,
nvinfer1::DataType type,
nvinfer1::PluginFormat format,
int maxBatchSize) override;
// *NOTE* The following functions need to be overridden in the subclass.
virtual nvinfer1::IPluginExt* clone() const = 0;
virtual const char* getPluginType() const = 0;
// Initialize the layer for execution. This is called when the engine is
// created.
int initialize() override { return 0; }
// Serialize the layer config to buffer.
virtual void serialize(void* buffer) = 0;
// Number of bytes serialize() writes.
virtual size_t getSerializationSize() = 0;
// Runs the layer on `stream` for `batchSize` samples.
virtual int enqueue(int batchSize, const void* const* inputs, void** outputs,
void* workspace, cudaStream_t stream) = 0;
protected:
// Deserialize input_dims, max_batch_size, data_type, data_format
void deserializeBase(void const*& serialData, size_t& serialLength);
// Number of bytes serializeBase() writes.
size_t getBaseSerializationSize();
// Serialize input_dims, max_batch_size, data_type, data_format
void serializeBase(void*& buffer);
// Shared state; none of these members is initialized by the constructors
// above.
std::vector<nvinfer1::Dims> input_dims_;
size_t max_batch_size_;
nvinfer1::DataType data_type_;
nvinfer1::PluginFormat data_format_;
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -45,11 +45,7 @@ inference_analysis_api_test(test_analyzer_rnn2 ${RNN2_INSTALL_DIR} analyzer_rnn2
# DAM
set(DAM_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/dam")
download_model_and_data(${DAM_INSTALL_DIR} "DAM_model.tar.gz" "DAM_data.txt.tar.gz")
inference_analysis_test(test_analyzer_dam SRCS analyzer_dam_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS
--infer_model=${DAM_INSTALL_DIR}/model
--infer_data=${DAM_INSTALL_DIR}/data.txt
--use_analysis=0)
inference_analysis_api_test(test_analyzer_dam ${DAM_INSTALL_DIR} analyzer_dam_tester.cc)
# chinese_ner
set(CHINESE_NER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/chinese_ner")
......@@ -108,8 +104,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR})
inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_test_models.tar.gz")
endif()
inference_analysis_test(test_trt_models SRCS trt_models_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} analysis ${analysis_deps} ir_pass_manager analysis_predictor
ARGS --dirname=${TRT_MODEL_INSTALL_DIR}/trt_test_models SERIAL)
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_test_models SERIAL)
endif()
......@@ -69,7 +69,7 @@ struct DataRecord {
num_lines++;
std::vector<std::string> data;
split(line, ',', &data);
CHECK_EQ(data.size(), 2 * MAX_TURN_NUM + 3);
CHECK_EQ(data.size(), (size_t)(2 * MAX_TURN_NUM + 3));
// load turn data
std::vector<int64_t> turns_tmp[MAX_TURN_NUM];
for (int i = 0; i < MAX_TURN_NUM; ++i) {
......@@ -178,7 +178,8 @@ TEST(Analyzer_dam, profile) {
std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
PADDLE_ENFORCE_GT(outputs.size(), 0);
......@@ -196,15 +197,13 @@ TEST(Analyzer_dam, fuse_statis) {
contrib::AnalysisConfig cfg;
SetConfig(&cfg);
if (FLAGS_use_analysis) {
int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 317);
EXPECT_EQ(num_ops, 2020);
}
int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 317);
EXPECT_EQ(num_ops, 2020);
}
// Compare result of NativeConfig and AnalysisConfig
......@@ -215,9 +214,8 @@ TEST(Analyzer_dam, compare) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
if (FLAGS_use_analysis) {
CompareNativeAndAnalysis(cfg, input_slots_all);
}
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
} // namespace inference
......
......@@ -133,7 +133,8 @@ TEST(Analyzer_LAC, profile) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
// the first inference result
......@@ -175,7 +176,8 @@ TEST(Analyzer_LAC, compare) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(cfg, input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
} // namespace analysis
......
......@@ -121,7 +121,8 @@ TEST(Analyzer_Chinese_ner, profile) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
// the first inference result
......@@ -160,7 +161,8 @@ TEST(Analyzer_Chinese_ner, compare) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(cfg, input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
} // namespace inference
......
......@@ -45,7 +45,8 @@ void profile(bool use_mkldnn = false) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, FLAGS_num_threads);
}
TEST(Analyzer_resnet50, profile) { profile(); }
......@@ -74,7 +75,8 @@ void compare(bool use_mkldnn = false) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(cfg, input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
TEST(Analyzer_resnet50, compare) { compare(); }
......
......@@ -233,8 +233,8 @@ TEST(Analyzer_rnn1, profile) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
LOG(INFO) << "to test prediction";
TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, FLAGS_num_threads);
}
// Check the fuse status
......@@ -261,7 +261,8 @@ TEST(Analyzer_rnn1, compare) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(cfg, input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
// Test Multi-Thread.
......@@ -272,7 +273,8 @@ TEST(Analyzer_rnn1, multi_thread) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(cfg, input_slots_all, &outputs, 4 /* multi_thread */);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, 4 /* multi_thread */);
}
// Validate that the AnalysisPredictor + ZeroCopyTensor really works by testing
......
......@@ -132,7 +132,8 @@ TEST(Analyzer_rnn2, profile) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
// the first inference result
......@@ -153,7 +154,8 @@ TEST(Analyzer_rnn2, compare) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(cfg, input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
} // namespace inference
......
......@@ -161,7 +161,8 @@ TEST(Analyzer_seq_conv1, profile) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
// the first inference result
......@@ -199,7 +200,8 @@ TEST(Analyzer_seq_conv1, compare) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(cfg, input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
} // namespace inference
......
......@@ -74,7 +74,8 @@ TEST(Analyzer_Text_Classification, profile) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1) {
// Get output
......@@ -101,7 +102,8 @@ TEST(Analyzer_Text_Classification, compare) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(cfg, input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
......@@ -112,7 +114,8 @@ TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(cfg, input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
} // namespace inference
......
......@@ -59,9 +59,6 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->specify_input_name = true;
// TODO(TJ): fix fusion gru
cfg->pass_builder()->DeletePass("fc_gru_fuse_pass");
#ifdef PADDLE_WITH_MKLDNN
cfg->EnableMKLDNN();
#endif
}
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
......@@ -94,7 +91,8 @@ void profile(bool use_mkldnn = false) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
TestPrediction(cfg, input_slots_all, &outputs, FLAGS_num_threads);
TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all, &outputs, FLAGS_num_threads);
if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
const float ocr_result_data[] = {
......@@ -136,7 +134,8 @@ void compare(bool use_mkldnn = false) {
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(cfg, input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}
TEST(Analyzer_vis, compare) { compare(); }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <ostream>
#include <string>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace paddle {
namespace inference {
thread_local int num_spaces = 0;
// Builds the indentation prefix used by the config printers below.
//
// Returns the indentation unit repeated `num_spaces` times; a zero or negative
// count yields an empty string. The prefix is accumulated directly in a
// std::string, so the helper depends only on <string> (the previous
// std::ostringstream needed <sstream>, which this header never includes).
static std::string GenSpaces(int num_spaces) {
  std::string indent;
  for (int i = 0; i < num_spaces; ++i) {
    indent += " ";
  }
  return indent;
}
// Streams the base predictor config as an indented block:
//   PaddlePredictor::Config {
//     model_dir: ...
//   }
// The shared `num_spaces` depth is raised by one while the field is written
// and restored before the closing brace is emitted.
std::ostream &operator<<(std::ostream &os,
                         const PaddlePredictor::Config &config) {
  const std::string header_indent = GenSpaces(num_spaces);
  os << header_indent << "PaddlePredictor::Config {\n";

  ++num_spaces;
  const std::string field_indent = GenSpaces(num_spaces);
  os << field_indent << "model_dir: " << config.model_dir << "\n";
  --num_spaces;

  const std::string footer_indent = GenSpaces(num_spaces);
  os << footer_indent << "}\n";
  return os;
}
// Streams a NativeConfig as an indented block: first the nested
// PaddlePredictor::Config block, then the NativeConfig-specific fields.
// `num_spaces` is raised by one for everything between the braces so nested
// output is indented one level deeper, and restored before the closing brace.
std::ostream &operator<<(std::ostream &os, const NativeConfig &config) {
  os << GenSpaces(num_spaces) << "NativeConfig {\n";
  num_spaces++;
  // Print the base-class part. static_cast performs a real derived-to-base
  // conversion (and is checked by the compiler), whereas reinterpret_cast
  // merely relabels the address and is only correct if the base subobject
  // happens to live at offset zero.
  os << static_cast<const PaddlePredictor::Config &>(config);
  os << GenSpaces(num_spaces) << "use_gpu: " << config.use_gpu << "\n";
  os << GenSpaces(num_spaces) << "device: " << config.device << "\n";
  os << GenSpaces(num_spaces)
     << "fraction_of_gpu_memory: " << config.fraction_of_gpu_memory << "\n";
  os << GenSpaces(num_spaces) << "prog_file: " << config.prog_file << "\n";
  os << GenSpaces(num_spaces) << "param_file: " << config.param_file << "\n";
  os << GenSpaces(num_spaces)
     << "specify_input_name: " << config.specify_input_name << "\n";
  num_spaces--;
  os << GenSpaces(num_spaces) << "}\n";
  return os;
}
// Streams a contrib::AnalysisConfig as an indented block: first the nested
// NativeConfig block, then the analysis-specific switches (IR optimisation,
// feed/fetch ops, TensorRT, MKLDNN). `num_spaces` is raised by one for the
// body and restored before the closing brace.
std::ostream &operator<<(std::ostream &os,
                         const contrib::AnalysisConfig &config) {
  os << GenSpaces(num_spaces) << "contrib::AnalysisConfig {\n";
  num_spaces++;
  // Print the base-class part through a compiler-checked upcast instead of
  // reinterpret_cast, which would silently produce a wrong pointer if the
  // NativeConfig subobject were ever not at offset zero.
  os << static_cast<const NativeConfig &>(config);
  os << GenSpaces(num_spaces) << "enable_ir_optim: " << config.enable_ir_optim
     << "\n";
  os << GenSpaces(num_spaces)
     << "use_feed_fetch_ops: " << config.use_feed_fetch_ops << "\n";
  os << GenSpaces(num_spaces) << "use_tensorrt: " << config.use_tensorrt()
     << "\n";
  os << GenSpaces(num_spaces) << "use_mkldnn: " << config.use_mkldnn() << "\n";
  num_spaces--;
  os << GenSpaces(num_spaces) << "}\n";
  return os;
}
} // namespace inference
} // namespace paddle
......@@ -19,13 +19,16 @@
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/tests/api/config_printer.h"
#include "paddle/fluid/inference/tests/test_helper.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -38,10 +41,18 @@ DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads.");
DEFINE_bool(use_analysis, true,
"Running the inference program in analysis mode.");
DECLARE_bool(profile);
namespace paddle {
namespace inference {
using contrib::AnalysisConfig;
// Logs `config` at INFO level.
// When `use_analysis` is true the pointer is reinterpreted as a
// contrib::AnalysisConfig and printed with that type's printer; otherwise it
// is printed as a plain PaddlePredictor::Config.
void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
  if (!use_analysis) {
    LOG(INFO) << *config;
    return;
  }
  const auto *analysis_config =
      reinterpret_cast<const contrib::AnalysisConfig *>(config);
  LOG(INFO) << *analysis_config;
}
void CompareResult(const std::vector<PaddleTensor> &outputs,
const std::vector<PaddleTensor> &ref_outputs) {
......@@ -77,12 +88,13 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
}
std::unique_ptr<PaddlePredictor> CreateTestPredictor(
const AnalysisConfig &config, bool use_analysis = true) {
const PaddlePredictor::Config *config, bool use_analysis = true) {
if (use_analysis) {
return CreatePaddlePredictor<contrib::AnalysisConfig>(config);
} else {
return CreatePaddlePredictor<NativeConfig>(config);
return CreatePaddlePredictor<contrib::AnalysisConfig>(
*(reinterpret_cast<const contrib::AnalysisConfig *>(config)));
}
return CreatePaddlePredictor<NativeConfig>(
*(reinterpret_cast<const NativeConfig *>(config)));
}
size_t GetSize(const PaddleTensor &out) { return VecReduceToInt(out.shape); }
......@@ -111,11 +123,23 @@ std::unordered_map<std::string, int> GetFuseStatis(PaddlePredictor *predictor,
}
void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
const std::string &dirname) {
const std::string &dirname, bool is_combined = true,
std::string model_filename = "model",
std::string params_filename = "params") {
// Set fake_image_data
PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data.");
std::vector<std::vector<int64_t>> feed_target_shapes =
GetFeedTargetShapes(dirname, true, "model", "params");
std::vector<std::vector<int64_t>> feed_target_shapes = GetFeedTargetShapes(
dirname, is_combined, model_filename, params_filename);
std::ostringstream os;
for (size_t i = 0; i < feed_target_shapes.size(); ++i) {
os << "feed target " << i << ": {" << feed_target_shapes[i][0];
for (size_t j = 1; j < feed_target_shapes[i].size(); ++j) {
os << ", " << feed_target_shapes[i][j];
}
os << "}\n";
}
LOG(INFO) << os.str();
int dim1 = feed_target_shapes[0][1];
int dim2 = feed_target_shapes[0][2];
int dim3 = feed_target_shapes[0][3];
......@@ -139,25 +163,43 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
}
void TestOneThreadPrediction(
const AnalysisConfig &config,
const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, bool use_analysis = true) {
int batch_size = FLAGS_batch_size;
int num_times = FLAGS_repeat;
auto predictor = CreateTestPredictor(config, use_analysis);
Timer timer;
timer.tic();
for (int i = 0; i < num_times; i++) {
for (size_t j = 0; j < inputs.size(); j++) {
predictor->Run(inputs[j], outputs);
// warmup run
LOG(INFO) << "Warm up run...";
{
Timer warmup_timer;
warmup_timer.tic();
predictor->Run(inputs[0], outputs, batch_size);
PrintTime(batch_size, 1, 1, 0, warmup_timer.toc(), 1);
#if !defined(_WIN32)
if (FLAGS_profile) {
paddle::platform::ResetProfiler();
}
#endif
}
LOG(INFO) << "Run " << num_times << " times...";
{
Timer run_timer;
run_timer.tic();
for (int i = 0; i < num_times; i++) {
for (size_t j = 0; j < inputs.size(); j++) {
predictor->Run(inputs[j], outputs, batch_size);
}
}
PrintTime(batch_size, num_times, 1, 0, run_timer.toc() / num_times,
inputs.size());
}
PrintTime(batch_size, num_times, 1, 0, timer.toc() / num_times,
inputs.size());
}
void TestMultiThreadPrediction(
const AnalysisConfig &config,
const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, int num_threads,
bool use_analysis = true) {
......@@ -200,12 +242,11 @@ void TestMultiThreadPrediction(
}
}
void TestPrediction(const AnalysisConfig &config,
void TestPrediction(const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs,
std::vector<PaddleTensor> *outputs, int num_threads,
bool use_analysis = FLAGS_use_analysis) {
LOG(INFO) << "use_analysis: " << use_analysis
<< ", use_mkldnn: " << config.use_mkldnn();
PrintConfig(config, use_analysis);
if (num_threads == 1) {
TestOneThreadPrediction(config, inputs, outputs, use_analysis);
} else {
......@@ -215,9 +256,9 @@ void TestPrediction(const AnalysisConfig &config,
}
void CompareNativeAndAnalysis(
const AnalysisConfig &config,
const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs) {
LOG(INFO) << "use_mkldnn: " << config.use_mkldnn();
PrintConfig(config, true);
std::vector<PaddleTensor> native_outputs, analysis_outputs;
TestOneThreadPrediction(config, inputs, &native_outputs, false);
TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
namespace paddle {
using paddle::contrib::AnalysisConfig;
DEFINE_string(dirname, "", "Directory of the inference model.");
// Returns a NativeConfig that runs on GPU device 0 with a 0.15 GPU memory
// fraction and takes its model directory from the --dirname flag.
NativeConfig GetConfigNative() {
  NativeConfig native_config;
  native_config.use_gpu = true;
  native_config.device = 0;
  native_config.fraction_of_gpu_memory = 0.15;
  native_config.model_dir = FLAGS_dirname;
  // LOG(INFO) << "dirname " << native_config.model_dir;
  return native_config;
}
void PrepareTRTConfig(AnalysisConfig *config) {
config->model_dir = FLAGS_dirname + "/" + "mobilenet";
config->fraction_of_gpu_memory = 0.15;
config->EnableTensorRtEngine(1 << 10, 5);
config->pass_builder()->DeletePass("conv_bn_fuse_pass");
config->pass_builder()->DeletePass("fc_fuse_pass");
config->pass_builder()->TurnOnDebug();
namespace inference {
DEFINE_bool(use_tensorrt, true, "Test the performance of TensorRT engine.");
DEFINE_string(prog_filename, "", "Name of model file.");
DEFINE_string(param_filename, "", "Name of parameters file.");
// Fills in the model location and device settings of `config`.
//
// If both --prog_filename and --param_filename are set, the program and
// parameter files are taken from `model_dir`; otherwise `model_dir` itself is
// used as the model directory. With `use_gpu`, GPU device 0 is selected with a
// 0.15 memory fraction. `use_tensorrt` and `batch_size` are unused in this
// generic version.
template <typename ConfigType>
void SetConfig(ConfigType* config, std::string model_dir, bool use_gpu,
               bool use_tensorrt = false, int batch_size = -1) {
  const bool has_explicit_files =
      !(FLAGS_prog_filename.empty() || FLAGS_param_filename.empty());
  if (has_explicit_files) {
    config->prog_file = model_dir + "/" + FLAGS_prog_filename;
    config->param_file = model_dir + "/" + FLAGS_param_filename;
  } else {
    config->model_dir = model_dir;
  }

  if (!use_gpu) {
    return;
  }
  config->use_gpu = true;
  config->device = 0;
  config->fraction_of_gpu_memory = 0.15;
}
void PrepareInputs(std::vector<PaddleTensor> *tensors, int batch_size) {
PADDLE_ENFORCE_EQ(tensors->size(), 1UL);
auto &tensor = tensors->front();
int height = 224;
int width = 224;
float *data = new float[batch_size * 3 * height * width];
memset(data, 0, sizeof(float) * (batch_size * 3 * height * width));
data[0] = 1.0f;
// Prepare inputs
tensor.name = "input_0";
tensor.shape = std::vector<int>({batch_size, 3, height, width});
tensor.data = PaddleBuf(static_cast<void *>(data),
sizeof(float) * (batch_size * 3 * height * width));
tensor.dtype = PaddleDType::FLOAT32;
// AnalysisConfig specialization of SetConfig.
//
// Model location and basic GPU settings are filled in as in the generic
// version. On GPU it additionally either enables the TensorRT engine
// (removing conv_bn_fuse_pass and fc_fuse_pass and turning on pass debugging)
// when `use_tensorrt` is set, or enables IR optimisation otherwise. Nothing
// beyond the model location is touched when `use_gpu` is false.
template <>
void SetConfig<contrib::AnalysisConfig>(contrib::AnalysisConfig* config,
                                        std::string model_dir, bool use_gpu,
                                        bool use_tensorrt, int batch_size) {
  const bool has_explicit_files =
      !(FLAGS_prog_filename.empty() || FLAGS_param_filename.empty());
  if (has_explicit_files) {
    config->prog_file = model_dir + "/" + FLAGS_prog_filename;
    config->param_file = model_dir + "/" + FLAGS_param_filename;
  } else {
    config->model_dir = model_dir;
  }

  if (!use_gpu) {
    return;
  }
  config->use_gpu = true;
  config->device = 0;
  config->fraction_of_gpu_memory = 0.15;

  if (!use_tensorrt) {
    config->enable_ir_optim = true;
    return;
  }
  config->EnableTensorRtEngine(1 << 10, batch_size);
  auto* passes = config->pass_builder();
  passes->DeletePass("conv_bn_fuse_pass");
  config->pass_builder()->DeletePass("fc_fuse_pass");
  config->pass_builder()->TurnOnDebug();
}
void CompareTensorRTWithFluid(int batch_size, std::string model_dirname) {
auto config0 = GetConfigNative();
config0.model_dir = model_dirname;
AnalysisConfig config1(true);
PrepareTRTConfig(&config1);
config1.model_dir = model_dirname;
auto predictor0 = CreatePaddlePredictor<NativeConfig>(config0);
auto predictor1 = CreatePaddlePredictor(config1);
// Prepare inputs
std::vector<PaddleTensor> paddle_tensor_feeds(1);
PrepareInputs(&paddle_tensor_feeds, batch_size);
// Prepare outputs
std::vector<PaddleTensor> outputs0;
std::vector<PaddleTensor> outputs1;
CHECK(predictor0->Run(paddle_tensor_feeds, &outputs0));
CHECK(predictor1->Run(paddle_tensor_feeds, &outputs1, batch_size));
const size_t num_elements = outputs0.front().data.length() / sizeof(float);
const size_t num_elements1 = outputs1.front().data.length() / sizeof(float);
EXPECT_EQ(num_elements, num_elements1);
auto *data0 = static_cast<float *>(outputs0.front().data.data());
auto *data1 = static_cast<float *>(outputs1.front().data.data());
ASSERT_GT(num_elements, 0UL);
for (size_t i = 0; i < std::min(num_elements, num_elements1); i++) {
EXPECT_NEAR(data0[i], data1[i], 1e-3);
void profile(std::string model_dir, bool use_analysis, bool use_tensorrt) {
std::vector<std::vector<PaddleTensor>> inputs_all;
if (!FLAGS_prog_filename.empty() && !FLAGS_param_filename.empty()) {
SetFakeImageInput(&inputs_all, model_dir, true, FLAGS_prog_filename,
FLAGS_param_filename);
} else {
SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
}
}
TEST(trt_models_test, mobilenet) {
CompareTensorRTWithFluid(1, FLAGS_dirname + "/" + "mobilenet");
}
TEST(trt_models_test, resnet50) {
CompareTensorRTWithFluid(1, FLAGS_dirname + "/" + "resnet50");
}
TEST(trt_models_test, resnext50) {
CompareTensorRTWithFluid(1, FLAGS_dirname + "/" + "resnext50");
std::vector<PaddleTensor> outputs;
if (use_analysis || use_tensorrt) {
contrib::AnalysisConfig config(true);
SetConfig<contrib::AnalysisConfig>(&config, model_dir, true, use_tensorrt,
FLAGS_batch_size);
TestPrediction(reinterpret_cast<PaddlePredictor::Config*>(&config),
inputs_all, &outputs, FLAGS_num_threads, true);
} else {
NativeConfig config;
SetConfig<NativeConfig>(&config, model_dir, true, false);
TestPrediction(reinterpret_cast<PaddlePredictor::Config*>(&config),
inputs_all, &outputs, FLAGS_num_threads, false);
}
}
TEST(trt_models_test, raw_gpu) {
std::string model_dir = FLAGS_dirname + "/" + "mobilenet";
auto config0 = GetConfigNative();
config0.model_dir = model_dir;
int batch_size = 2;
AnalysisConfig config1(true);
config1.fraction_of_gpu_memory = 0.1;
config1.enable_ir_optim = true;
config1.model_dir = model_dir;
void compare(std::string model_dir, bool use_tensorrt) {
std::vector<std::vector<PaddleTensor>> inputs_all;
if (!FLAGS_prog_filename.empty() && !FLAGS_param_filename.empty()) {
SetFakeImageInput(&inputs_all, model_dir, true, FLAGS_prog_filename,
FLAGS_param_filename);
} else {
SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
}
auto predictor0 = CreatePaddlePredictor<NativeConfig>(config0);
auto predictor1 = CreatePaddlePredictor(config1);
std::vector<PaddleTensor> native_outputs;
NativeConfig native_config;
SetConfig<NativeConfig>(&native_config, model_dir, true, false,
FLAGS_batch_size);
TestOneThreadPrediction(
reinterpret_cast<PaddlePredictor::Config*>(&native_config), inputs_all,
&native_outputs, false);
std::vector<PaddleTensor> analysis_outputs;
contrib::AnalysisConfig analysis_config(true);
SetConfig<contrib::AnalysisConfig>(&analysis_config, model_dir, true,
use_tensorrt, FLAGS_batch_size);
TestOneThreadPrediction(
reinterpret_cast<PaddlePredictor::Config*>(&analysis_config), inputs_all,
&analysis_outputs, true);
CompareResult(native_outputs, analysis_outputs);
}
// Prepare inputs
std::vector<PaddleTensor> paddle_tensor_feeds(1);
PrepareInputs(&paddle_tensor_feeds, batch_size);
// Compares native output against the TensorRT-enabled analysis predictor on
// the mobilenet model under --infer_model.
TEST(TensorRT_mobilenet, compare) {
  compare(FLAGS_infer_model + "/mobilenet", /* use_tensorrt */ true);
}
// Prepare outputs
std::vector<PaddleTensor> outputs0;
std::vector<PaddleTensor> outputs1;
CHECK(predictor0->Run(paddle_tensor_feeds, &outputs0));
CHECK(predictor1->Run(paddle_tensor_feeds, &outputs1, batch_size));
// Compares native output against the TensorRT-enabled analysis predictor on
// the resnet50 model under --infer_model.
TEST(TensorRT_resnet50, compare) {
  compare(FLAGS_infer_model + "/resnet50", /* use_tensorrt */ true);
}
const size_t num_elements = outputs0.front().data.length() / sizeof(float);
const size_t num_elements1 = outputs1.front().data.length() / sizeof(float);
EXPECT_EQ(num_elements, num_elements1);
// Compares native output against the TensorRT-enabled analysis predictor on
// the resnext50 model under --infer_model.
TEST(TensorRT_resnext50, compare) {
  compare(FLAGS_infer_model + "/resnext50", /* use_tensorrt */ true);
}
auto *data0 = static_cast<float *>(outputs0.front().data.data());
auto *data1 = static_cast<float *>(outputs1.front().data.data());
// Profiles the resnext50 model under --infer_model with the analysis
// predictor; TensorRT is toggled by --use_tensorrt.
TEST(TensorRT_resnext50, profile) {
  profile(FLAGS_infer_model + "/resnext50", /* use_analysis */ true,
          FLAGS_use_tensorrt);
}
ASSERT_GT(num_elements, 0UL);
for (size_t i = 0; i < std::min(num_elements, num_elements1); i++) {
EXPECT_NEAR(data0[i], data1[i], 1e-3);
}
// Compares native output against the analysis predictor with TensorRT
// disabled, on the mobilenet model under --infer_model.
TEST(TensorRT_mobilenet, analysis) {
  compare(FLAGS_infer_model + "/mobilenet", /* use_tensorrt */ false);
}
} // namespace inference
} // namespace paddle
USE_PASS(tensorrt_subgraph_pass);
......@@ -27,11 +27,9 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
"Out(Output) of Fully Connected should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"W(Input) of Fully Connected should not be null.");
// NCHW
auto in_dims = ctx->GetInputDim("Input");
// IO, I=C*H*W
auto w_dims = ctx->GetInputDim("W");
std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
......@@ -44,14 +42,32 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
"The shape of Bias must be [1, dim].");
}
}
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor.");
if (ctx->Attrs().Get<bool>("use_mkldnn")) {
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor.");
}
PADDLE_ENFORCE_EQ(w_dims.size(), 2UL,
"Fully Connected input should be 2-D tensor.");
PADDLE_ENFORCE_EQ(framework::product(in_dims) / in_dims[0], w_dims[0],
"Fully Connected input and weigth size do not match.");
int in_num_col_dims = ctx->Attrs().Get<int>("in_num_col_dims");
PADDLE_ENFORCE_GT(
in_dims.size(), in_num_col_dims,
"The input tensor Input's rank of FCOp should be larger than "
"in_num_col_dims.");
auto in_mat_dims = framework::flatten_to_2d(in_dims, in_num_col_dims);
PADDLE_ENFORCE_EQ(
in_mat_dims[1], w_dims[0],
"Fully Connected input and weigth size do not match. %s, %s");
std::vector<int64_t> output_dims;
output_dims.reserve(static_cast<size_t>(in_num_col_dims + 1));
for (int i = 0; i < in_num_col_dims; ++i) {
output_dims.push_back(in_dims[i]);
}
output_dims.push_back(w_dims[1]);
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("Input", "Out");
}
......@@ -101,12 +117,15 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
}
void FCOpMaker::Make() {
AddInput("Input",
"(Tensor), The input tensor of fully connected operator with format "
"(NCHW). ");
AddInput("Input", "(Tensor), The input tensor of fully connected operator.");
AddInput("W", "(Tensor), The weight fc op with shape (I, O).");
AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O")
.AsDispensable();
AddAttr<int>("in_num_col_dims",
"(int, default 1), The fc op can take tensors with more than "
"two dimensions as its inputs.")
.SetDefault(1)
.EqualGreaterThan(1);
AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
......@@ -131,13 +150,15 @@ class FCOpKernel : public framework::OpKernel<T> {
auto output = ctx.Output<Tensor>("Out");
auto in_dims = input->dims();
auto w_dims = w->dims();
auto out_dims = output->dims();
int M = framework::product(out_dims) / out_dims[out_dims.size() - 1];
const T* input_data = input->data<T>();
const T* w_data = w->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
math::FCCompute<platform::CPUDeviceContext, T>(
blas, in_dims[0], w_dims[1], w_dims[0], input_data, w_data, output_data,
blas, M, w_dims[1], w_dims[0], input_data, w_data, output_data,
bias ? bias->data<T>() : NULL);
// TODO(TJ): fuse act
......
......@@ -38,7 +38,7 @@ class HashOp : public framework::OperatorWithKernel {
std::vector<int64_t> out_dims;
out_dims.reserve(dims.size() + 1);
// copy all dims except the last one
for (size_t i = 0u; i != dims.size() - 1; ++i) {
for (int i = 0u; i != dims.size() - 1; ++i) {
out_dims.emplace_back(dims[i]);
}
int num_hash = ctx->Attrs().Get<int>("num_hash");
......
......@@ -118,6 +118,39 @@ void VXXJitCode::generate() {
ret();
}
// Reports whether the ReLU JIT code may be used. The decision depends only on
// MayIUse(avx); the size `d` is not consulted.
bool ReluJitCode::init(int d) {
  const bool avx_available = MayIUse(avx);
  return avx_available;
}
// Emits straight-line code computing dst[i] = max(0, src[i]) for num_ floats,
// where param1 holds the source pointer and param2 the destination pointer.
//
// The bulk is processed AVX_FLOAT_BLOCK floats at a time through a ymm
// register; the remainder is handled by at most one 4-float, one 2-float and
// one 1-float step. Each tail step loads exactly as many bytes as it stores
// (vmovups / vmovq / vmovss), so the generated code never reads past the end
// of the source buffer.
void ReluJitCode::generate() {
  int offset = 0;
  // Zero register 0; xmm_zero shares that register index, so the tail steps
  // below can use it as their zero operand as well.
  vxorps(ymm_zero, ymm_zero, ymm_zero);
  for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
    vmovups(ymm_src, ptr[param1 + offset]);
    vmaxps(ymm_dst, ymm_zero, ymm_src);
    vmovups(ptr[param2 + offset], ymm_dst);
    offset += sizeof(float) * AVX_FLOAT_BLOCK;
  }
  int rest = num_ % AVX_FLOAT_BLOCK;
  if (rest >= 4) {
    // Four floats: full 128-bit load and store.
    vmovups(xmm_src, ptr[param1 + offset]);
    vmaxps(xmm_dst, xmm_zero, xmm_src);
    vmovups(ptr[param2 + offset], xmm_dst);
    offset += sizeof(float) * 4;
    rest -= 4;
  }
  if (rest >= 2) {
    // Two floats: 64-bit load (upper lanes are zeroed) and 64-bit store.
    vmovq(xmm_src, ptr[param1 + offset]);
    vmaxps(xmm_dst, xmm_zero, xmm_src);
    vmovq(ptr[param2 + offset], xmm_dst);
    offset += sizeof(float) * 2;
    rest -= 2;
  }
  if (rest > 0) {
    // Last float: scalar load (upper lanes are zeroed) and scalar store.
    vmovss(xmm_src, ptr[param1 + offset]);
    vmaxps(xmm_dst, xmm_zero, xmm_src);
    vmovss(ptr[param2 + offset], xmm_dst);
  }
  ret();
}
} // namespace gen
} // namespace jitkernel
} // namespace math
......
......@@ -85,6 +85,29 @@ class VXXJitCode : public JitCode {
ymm_t ymm_zero = ymm_t(3);
};
// JIT code generator for an element-wise ReLU over `d` values: generate()
// emits code that loads from the pointer in param1, takes the maximum with
// zero, and stores to the pointer in param2.
class ReluJitCode : public JitCode {
public:
DECLARE_JIT_CODE(ReluJitCode);
// `d` is the number of elements the generated code processes; `code_size` and
// `code_ptr` are forwarded unchanged to the JitCode base.
explicit ReluJitCode(int d, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d) {}
// Returns whether this JIT code may be used for size `d`.
static bool init(int d);
// Emits the machine code for the ReLU routine.
void generate() override;
private:
// Element count the generated code is specialised for.
int num_;
// First and second ABI argument registers: source and destination pointers.
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
// Register index 0 holds the zero operand; index 1 is used both as the load
// target and as the result register (src and dst deliberately share it).
xmm_t xmm_zero = xmm_t(0);
xmm_t xmm_src = xmm_t(1);
xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_zero = ymm_t(0);
ymm_t ymm_src = ymm_t(1);
ymm_t ymm_dst = ymm_t(1);
};
} // namespace gen
} // namespace jitkernel
} // namespace math
......
......@@ -97,37 +97,38 @@ class VAddBiasKernel : public Kernel {
template <typename T>
class VActKernel : public Kernel {
public:
virtual void Compute(const T *x, T *y) const = 0;
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T>
class VReluKernel : public VActKernel<T> {
public:
virtual void Compute(const T *x, T *y) const = 0;
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
void (*Compute)(const T *, T *, int);
};
template <typename T>
class VIdentityKernel : public VActKernel<T> {
public:
virtual void Compute(const T *x, T *y) const = 0;
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T>
class VExpKernel : public VActKernel<T> {
public:
virtual void Compute(const T *x, T *y) const = 0;
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T>
class VSigmoidKernel : public VActKernel<T> {
public:
virtual void Compute(const T *x, T *y) const = 0;
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T>
class VTanhKernel : public VActKernel<T> {
public:
virtual void Compute(const T *x, T *y) const = 0;
virtual void ComputeDeprecated(const T *x, T *y) const = 0;
};
template <typename T>
......
......@@ -71,6 +71,13 @@ void VAddBiasRefer(const T* a, const T* x, T* y, int n) {
}
}
// Reference (non-vectorised) ReLU: writes max(x[i], 0) to y[i] for the first
// `n` elements. Values that do not compare greater than zero become zero.
// `x` and `y` may refer to the same buffer.
template <typename T>
void VReluRefer(const T* x, T* y, int n) {
  for (int idx = 0; idx < n; ++idx) {
    const T value = x[idx];
    if (value > 0) {
      y[idx] = value;
    } else {
      y[idx] = 0;
    }
  }
}
#ifdef PADDLE_WITH_MKLML
template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n);
......@@ -344,124 +351,60 @@ bool VAddBiasKernelImpl<float>::useJIT(int d) {
}
#endif
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
/* VRelu JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
template <typename T>
class VReluKernelImpl : public VReluKernel<T> {
public:
explicit VReluKernelImpl(int d) : VReluKernel<T>() { this->num_ = d; }
void Compute(const T* x, T* y) const override {
for (int i = 0; i < this->num_; ++i) {
y[i] = x[i] > 0 ? x[i] : 0;
DECLARE_STATIC_FUNC;
explicit VReluKernelImpl(int d) : VReluKernel<T>() {
this->num_ = d; // TODO(TJ): remove me when ComputeDeprecated done
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 /*init*/ +
d / AVX_FLOAT_BLOCK * 4 /* instructions*/ *
8 /*everage byte for each instruction*/;
jitcode_.reset(new gen::ReluJitCode(d, sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
}
}
};
#define INTRI8_FLOAT(isa) \
template <> \
void VReluKernelImpl<float, isa, kEQ8>::Compute(const float* x, float* y) \
const { \
__m256 tmp = _mm256_loadu_ps(x); \
tmp = _mm256_max_ps(tmp, _mm256_setzero_ps()); \
_mm256_storeu_ps(y, tmp); \
}
#define INTRI16_FLOAT(isa) \
template <> \
void VReluKernelImpl<float, isa, kEQ16>::Compute(const float* x, float* y) \
const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + 8); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + 8, tmp1); \
}
#endif
#define INTRI_GT8LT16_FLOAT(isa) \
template <> \
VReluKernelImpl<float, isa, kGT8LT16>::VReluKernelImpl(int d) \
: VReluKernel<float>() { \
this->num_ = d; \
this->end_ = AVX_FLOAT_BLOCK; \
this->rest_ = d - AVX_FLOAT_BLOCK; \
} \
template <> \
void VReluKernelImpl<float, isa, kGT8LT16>::Compute(const float* x, \
float* y) const { \
__m256 zeros = _mm256_setzero_ps(); \
__m256 tmp0 = _mm256_loadu_ps(x); \
__m256 tmp1 = _mm256_loadu_ps(x + this->rest_); \
tmp0 = _mm256_max_ps(tmp0, zeros); \
tmp1 = _mm256_max_ps(tmp1, zeros); \
_mm256_storeu_ps(y, tmp0); \
_mm256_storeu_ps(y + this->rest_, tmp1); \
this->Compute = VReluRefer<T>;
}
#define INTRI_GT16_FLOAT(isa) \
template <> \
VReluKernelImpl<float, isa, kGT16>::VReluKernelImpl(int d) \
: VReluKernel<float>() { \
this->num_ = d; \
this->end_ = d - d % AVX_FLOAT_BLOCK; \
this->rest_ = d - AVX_FLOAT_BLOCK; \
} \
template <> \
void VReluKernelImpl<float, isa, kGT16>::Compute(const float* x, float* y) \
const { \
__m256 zeros = _mm256_setzero_ps(); \
for (int i = 0; i < this->end_; i += AVX_FLOAT_BLOCK) { \
__m256 tmp = _mm256_loadu_ps(x + i); \
tmp = _mm256_max_ps(tmp, zeros); \
_mm256_storeu_ps(y + i, tmp); \
} \
__m256 tmp = _mm256_loadu_ps(x + this->rest_); \
tmp = _mm256_max_ps(tmp, zeros); \
_mm256_storeu_ps(y + this->rest_, tmp); \
void ComputeDeprecated(const T* x, T* y) const override {
VReluRefer(x, y, this->num_);
}
#ifdef PADDLE_WITH_XBYAK
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
INTRI16_FLOAT(jit::avx);
INTRI_GT8LT16_FLOAT(jit::avx);
INTRI_GT16_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
INTRI16_FLOAT(jit::avx2);
INTRI_GT8LT16_FLOAT(jit::avx2);
INTRI_GT16_FLOAT(jit::avx2);
private:
std::unique_ptr<gen::ReluJitCode> jitcode_{nullptr};
#endif
#ifdef __AVX512F__
// TODO(TJ): refine avx512
INTRI8_FLOAT(jit::avx512f);
INTRI16_FLOAT(jit::avx512f);
INTRI_GT8LT16_FLOAT(jit::avx512f);
INTRI_GT16_FLOAT(jit::avx512f);
};
#ifdef PADDLE_WITH_XBYAK
// float specialization: the JIT path is taken exactly when
// gen::ReluJitCode::init(d) returns true.
template <>
bool VReluKernelImpl<float>::useJIT(int d) {
return gen::ReluJitCode::init(d);
}
#endif
#undef INTRI8_FLOAT
#undef INTRI16_FLOAT
#undef INTRI_GT8LT16_FLOAT
#undef INTRI_GT16_FLOAT
#undef DECLARE_STATIC_FUNC
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
REGISTER_JITKERNEL(vrelu, VReluKernel);
/* An empty JitKernel */
template <typename T, platform::jit::cpu_isa_t isa, jit_block>
class VIdentityKernelImpl : public VIdentityKernel<T> {
public:
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() { this->num_ = d; }
void Compute(const T* x, T* y) const override {}
void ComputeDeprecated(const T* x, T* y) const override {}
};
REGISTER_JITKERNEL_DEPRECATED(vrelu, VReluKernel);
REGISTER_JITKERNEL_DEPRECATED(videntity, VIdentityKernel);
} // namespace jitkernel
......
......@@ -175,26 +175,26 @@ class LSTMKernelImpl : public LSTMKernel<T> {
// Computes the cell state ct and the hidden state ht from the gate buffer and
// the previous cell state ct_1. The gate buffer is modified in place and is
// also used as scratch space. wp_data and checked are not read in this body.
// NOTE(review): d2_ and d3_ are presumably 2 * d_ and 3 * d_ - confirm.
// NOTE(review): every act_* step below appears twice, once as Compute and once
// as ComputeDeprecated. The enclosing hunk header (-175,26 +175,26) suggests
// these are the old and new versions of one call; compiled as written, each
// in-place activation would run twice. Confirm and keep only one per step.
void ComputeCtHt(T* gates, const T* ct_1, T* ct, T* ht, const T* wp_data,
T* checked) const override {
// gates: W_ch, W_ih, W_fh, W_oh
act_gate_d3_->Compute(gates + d_, gates + d_);
act_gate_d3_->ComputeDeprecated(gates + d_, gates + d_);
/* C_t = C_t-1 * fgated + cand_gated * igated */
act_cand_d_->Compute(gates, gates);
act_cand_d_->ComputeDeprecated(gates, gates);
// gates + d_ now holds cand * igated, gates + d2_ holds ct_1 * fgated.
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* H_t = act_cell(C_t) * ogated */
// gates + d2_ is reused as scratch for act_cell(C_t).
act_cell_d_->Compute(ct, gates + d2_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
// Variant of the cell/hidden update that takes no previous cell state: ct is
// the product of the activated candidate and input gate, and ht is
// act_cell(ct) times the activated output gate. wp_data is not read in this
// body.
// NOTE(review): each act_* call appears as a Compute / ComputeDeprecated pair;
// this looks like the before/after sides of a diff hunk. Compiled as written
// the in-place activations would be applied twice - confirm and keep one.
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/
act_gate_d_->Compute(gates + d_, gates + d_);
act_cand_d_->Compute(gates, gates);
act_gate_d_->ComputeDeprecated(gates + d_, gates + d_);
act_cand_d_->ComputeDeprecated(gates, gates);
vmul_d_->Compute(gates, gates + d_, ct, d_);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_);
act_cell_d_->Compute(ct, gates + d2_);
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
// gates + d2_ holds act_cell(C_t) at this point and is used as scratch.
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
......@@ -292,32 +292,32 @@ class PeepholeKernelImpl : public LSTMKernel<T> {
vmul_d_->Compute(wp_data, ct_1, checked, d_);
vmul_d_->Compute(wp_data + d_, ct_1, checked + d_, d_);
vadd_d2_->Compute(checked, gates + d_, gates + d_, d2_);
act_gate_d2_->Compute(gates + d_, gates + d_);
act_gate_d2_->ComputeDeprecated(gates + d_, gates + d_);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
act_cand_d_->Compute(gates, gates);
act_cand_d_->ComputeDeprecated(gates, gates);
vmul_d_->Compute(gates, gates + d_, gates + d_, d_);
vmul_d_->Compute(ct_1, gates + d2_, gates + d2_, d_);
vadd_d_->Compute(gates + d_, gates + d2_, ct, d_);
/* get ogated*/
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
act_gate_d_->Compute(gates + d3_, gates + d3_);
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
/* H_t = act_cell(C_t) * ogated */
act_cell_d_->Compute(ct, gates + d2_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
// Peephole variant of the update that takes no previous cell state. After
// forming ct, the weights at wp_data + d2_ are multiplied with ct and added to
// the output-gate slot (gates + d3_) before that gate is activated.
// NOTE(review): each act_* call appears as a Compute / ComputeDeprecated pair;
// this looks like the before/after sides of a diff hunk. Compiled as written
// the in-place activations would be applied twice - confirm and keep one.
void ComputeC1H1(T* gates, T* ct, T* ht, const T* wp_data) const override {
/* C_t = igated * cgated*/
act_gate_d_->Compute(gates + d_, gates + d_);
act_cand_d_->Compute(gates, gates);
act_gate_d_->ComputeDeprecated(gates + d_, gates + d_);
act_cand_d_->ComputeDeprecated(gates, gates);
vmul_d_->Compute(gates, gates + d_, ct, d_);
/* get outgated, put W_oc * C_t on igated */
// The input-gate slot (gates + d_) is no longer needed and is overwritten.
vmul_d_->Compute(wp_data + d2_, ct, gates + d_, d_);
vadd_d_->Compute(gates + d_, gates + d3_, gates + d3_, d_);
/* H_t = act_cell(C_t) * ogated */
act_gate_d_->Compute(gates + d3_, gates + d3_);
act_cell_d_->Compute(ct, gates + d2_);
act_gate_d_->ComputeDeprecated(gates + d3_, gates + d3_);
act_cell_d_->ComputeDeprecated(ct, gates + d2_);
vmul_d_->Compute(gates + d2_, gates + d3_, ht, d_);
}
......@@ -376,20 +376,20 @@ class GRUKernelImpl : public GRUKernel<T> {
}
// Activates the slots at gates and gates + d2_ in place, then writes their
// product to ht. No previous hidden state is taken by this method.
// NOTE(review): each act_* call appears as a Compute / ComputeDeprecated pair;
// this looks like the before/after sides of a diff hunk. Compiled as written
// the in-place activations would be applied twice - confirm and keep one.
void ComputeH1(T* gates, T* ht) const override {
act_gate_d_->Compute(gates, gates);
act_state_d_->Compute(gates + d2_, gates + d2_);
act_gate_d_->ComputeDeprecated(gates, gates);
act_state_d_->ComputeDeprecated(gates + d2_, gates + d2_);
vmul_d_->Compute(gates, gates + d2_, ht, d_);
}
// First half of the hidden-state update: activates the leading gate slots in
// place via act_gate_d2_, then writes the product of ht_1 and the slot at
// gates + d_ into ht.
// NOTE(review): the act_gate_d2_ call appears as a Compute /
// ComputeDeprecated pair; this looks like the before/after sides of a diff
// hunk. Compiled as written the activation would be applied twice - confirm.
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
// W: {W_update, W_reset; W_state}
act_gate_d2_->Compute(gates, gates);
act_gate_d2_->ComputeDeprecated(gates, gates);
vmul_d_->Compute(ht_1, gates + d_, ht, d_);
}
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
T* y = gates + d2_;
act_state_d_->Compute(y, y);
act_state_d_->ComputeDeprecated(y, y);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d_; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
......
......@@ -92,7 +92,7 @@ TEST(JitKernel, vrelu) {
#endif
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data);
ker->Compute(x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
VLOG(30) << "Vec size " << d
......@@ -181,7 +181,7 @@ TEST(JitKernel, vexp) {
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data);
ker->ComputeDeprecated(x_data, ztgt_data);
}
auto ttgte = GetCurrentUS();
......@@ -222,7 +222,7 @@ void vsigmoid_better(
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = 0.f - y[i];
}
vexp->Compute(y, y);
vexp->ComputeDeprecated(y, y);
for (int i = 0; i < n; ++i) {
y[i] = 1.f / (1.f + y[i]);
}
......@@ -253,7 +253,7 @@ TEST(JitKernel, vsigmoid) {
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data);
ker->ComputeDeprecated(x_data, ztgt_data);
}
auto ttgte = GetCurrentUS();
......@@ -287,7 +287,7 @@ void vtanh_better(
const int n, const float* x, float* y) {
const float a = 2.f, b = -1.f;
vscal->Compute(&a, x, y, n);
vsigmoid->Compute(y, y);
vsigmoid->ComputeDeprecated(y, y);
vscal->Compute(&a, y, y, n);
vaddbias->Compute(&b, y, y, n);
}
......@@ -321,7 +321,7 @@ TEST(JitKernel, vtanh) {
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data);
ker->ComputeDeprecated(x_data, ztgt_data);
}
auto ttgte = GetCurrentUS();
......@@ -344,8 +344,8 @@ void lstm_ctht_ref(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp_1,
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
vsigmoid_3d->Compute(gates + d, gates + d);
vtanh_d->Compute(gates, gates);
vsigmoid_3d->ComputeDeprecated(gates + d, gates + d);
vtanh_d->ComputeDeprecated(gates, gates);
const float *i = gates + d, *f = gates + d * 2, *o = gates + d * 3;
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
......@@ -355,7 +355,7 @@ void lstm_ctht_ref(
// H_t = act_cell(C_t) * ogated
float tmp = ct[k] * 2;
tmp = 0.f - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vexp_1->Compute(&tmp, &tmp);
vexp_1->ComputeDeprecated(&tmp, &tmp);
tmp = 2.f / (1.f + tmp) - 1.f;
ht[k] = tmp * o[k];
}
......@@ -373,13 +373,13 @@ void lstm_ctht_better(
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d,
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
int d2 = d * 2;
vsigmoid_3d->Compute(gates + d, gates + d);
vtanh_d->Compute(gates, gates);
vsigmoid_3d->ComputeDeprecated(gates + d, gates + d);
vtanh_d->ComputeDeprecated(gates, gates);
vmul_d->Compute(gates, gates + d, gates + d, d);
vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
vadd_d->Compute(gates + d, gates + d2, ct, d);
/* H_t = act_cell(C_t) * ogated */
vtanh_d->Compute(ct, gates + d2);
vtanh_d->ComputeDeprecated(ct, gates + d2);
vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
}
......@@ -736,7 +736,7 @@ void vaddrelu_better(
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
const float* x, const float* y, float* z, int d) {
vadd->Compute(x, y, z, d);
vrelu->Compute(z, z);
vrelu->ComputeDeprecated(z, z);
}
TEST(JitKernel, vaddrelu) {
......
......@@ -244,7 +244,7 @@ typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
size_t data_len, const T* in, T* out) {
for (int64_t i = 0; i < data_len; i++) {
for (size_t i = 0; i < data_len; i++) {
out[i] += in[i];
}
}
......
......@@ -70,11 +70,11 @@ void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
EXPECT_EQ(in_grad.lod(), lod);
if (paddle::platform::is_cpu_place(*place)) {
for (int64_t i = 0; i < in_grad.lod()[0].size() - 1; ++i) {
for (size_t i = 0; i < in_grad.lod()[0].size() - 1; ++i) {
int64_t begin = in_grad.lod()[0][i];
int64_t end = in_grad.lod()[0][i + 1];
paddle::framework::Tensor tmp = in_grad.Slice(begin, end);
for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) {
for (size_t j = 0; j != tmp.numel() / second_dim; ++j) {
for (int64_t m = 0; m != second_dim; ++m) {
EXPECT_EQ(tmp.data<T>()[m + j * second_dim],
out_grad.data<T>()[m + i * second_dim]);
......@@ -82,11 +82,11 @@ void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
}
}
} else {
for (int64_t i = 0; i < cpu_in_grad.lod()[0].size() - 1; ++i) {
for (size_t i = 0; i < cpu_in_grad.lod()[0].size() - 1; ++i) {
int64_t begin = cpu_in_grad.lod()[0][i];
int64_t end = cpu_in_grad.lod()[0][i + 1];
paddle::framework::Tensor tmp = cpu_in_grad.Slice(begin, end);
for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) {
for (size_t j = 0; j != tmp.numel() / second_dim; ++j) {
for (int64_t m = 0; m != second_dim; ++m) {
EXPECT_EQ(tmp.data<T>()[m + j * second_dim],
cpu_out_grad.data<T>()[m + i * second_dim]);
......
......@@ -19,8 +19,10 @@ namespace paddle {
namespace operators {
namespace math {
template class SoftmaxFunctor<platform::CPUDeviceContext, float>;
template class SoftmaxFunctor<platform::CPUDeviceContext, double>;
template class SoftmaxFunctor<platform::CPUDeviceContext, float, true>;
template class SoftmaxFunctor<platform::CPUDeviceContext, float, false>;
template class SoftmaxFunctor<platform::CPUDeviceContext, double, true>;
template class SoftmaxFunctor<platform::CPUDeviceContext, double, false>;
template class SoftmaxGradFunctor<platform::CPUDeviceContext, float>;
template class SoftmaxGradFunctor<platform::CPUDeviceContext, double>;
......
......@@ -98,9 +98,14 @@ template class SoftmaxGradCUDNNFunctor<float>;
template class SoftmaxGradCUDNNFunctor<double>;
template class SoftmaxGradCUDNNFunctor<platform::float16>;
template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16>;
template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
false>;
template class SoftmaxFunctor<platform::CUDADeviceContext, platform::float16,
true>;
template class SoftmaxFunctor<platform::CUDADeviceContext, float, false>;
template class SoftmaxFunctor<platform::CUDADeviceContext, double, false>;
template class SoftmaxFunctor<platform::CUDADeviceContext, float, true>;
template class SoftmaxFunctor<platform::CUDADeviceContext, double, true>;
template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
template class SoftmaxGradFunctor<platform::CUDADeviceContext, double>;
template class SoftmaxGradFunctor<platform::CUDADeviceContext,
......
......@@ -19,7 +19,7 @@ namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext, typename T>
template <typename DeviceContext, typename T, bool is_test>
class SoftmaxFunctor {
public:
void operator()(const DeviceContext& context, const framework::Tensor* X,
......
......@@ -32,10 +32,10 @@ struct ValueClip {
}
};
template <typename DeviceContext, typename T>
void SoftmaxFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
const framework::Tensor* X,
framework::Tensor* Y) {
template <typename DeviceContext, typename T, bool is_test>
void SoftmaxFunctor<DeviceContext, T, is_test>::operator()(
const DeviceContext& context, const framework::Tensor* X,
framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
......@@ -65,6 +65,39 @@ void SoftmaxFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
.broadcast(one_by_class));
}
// Partial specialization of SoftmaxFunctor for is_test == true.
//
// Views X and Y as 2-D matrices (dimension 0 is the batch, dimension 1 is the
// class axis) and writes the row-wise softmax of X into Y:
//   1. subtract each row's maximum from that row,
//   2. exponentiate,
//   3. multiply every row by the inverse of its sum.
//
// context: device context whose eigen_device() runs the Eigen expressions.
// X:       input logits; read only.
// Y:       output tensor; overwritten.
//
// operator() is declared under `public:` to match the primary template. A
// `class` defaults to private access, which would make the call operator
// unusable through this specialization.
template <typename DeviceContext, typename T>
class SoftmaxFunctor<DeviceContext, T, true> {
 public:
  void operator()(const DeviceContext& context, const framework::Tensor* X,
                  framework::Tensor* Y) {
    auto logits = EigenMatrix<T>::From(*X);
    auto softmax = EigenMatrix<T>::From(*Y);

    const int kBatchDim = 0;
    const int kClassDim = 1;

    const int batch_size = logits.dimension(kBatchDim);
    const int num_classes = logits.dimension(kClassDim);

    Eigen::DSizes<int, 1> along_class(kClassDim);
    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
    Eigen::DSizes<int, 2> one_by_class(1, num_classes);

    // Shift each row by its own maximum before exponentiating.
    auto shifted_logits = (logits -
                           logits.maximum(along_class)
                               .eval()
                               .reshape(batch_by_one)
                               .broadcast(one_by_class));

    softmax.device(*context.eigen_device()) = shifted_logits.exp();
    // Normalize: scale every row by the reciprocal of its sum.
    softmax.device(*context.eigen_device()) = (softmax *
                                               softmax.sum(along_class)
                                                   .inverse()
                                                   .eval()
                                                   .reshape(batch_by_one)
                                                   .broadcast(one_by_class));
  }
};
template <typename DeviceContext, typename T>
void SoftmaxGradFunctor<DeviceContext, T>::operator()(
const DeviceContext& context, const framework::Tensor* y,
......
......@@ -43,11 +43,11 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(ids.size(), outs.size(),
"the number of Ids and Out should be the same");
int row_ids_size = 0;
size_t row_ids_size = 0;
int row_size = 0;
int embedding_size = 0;
for (int i = 0; i < x_tensors.size(); ++i) {
for (size_t i = 0; i < x_tensors.size(); ++i) {
const auto *x_tensor = x_tensors[i];
const auto *row_id = row_ids[i];
......@@ -66,7 +66,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
std::unordered_map<int64_t, std::tuple<int64_t, int64_t>>
selected_rows_idx_map;
for (int i = 0; i < x_tensors.size(); ++i) {
for (size_t i = 0; i < x_tensors.size(); ++i) {
const auto *row_id = row_ids[i];
for (int j = 0; j < row_id->numel(); ++j) {
......@@ -78,7 +78,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(row_ids_size, selected_rows_idx_map.size(),
"the rows and tensor map size should be the same");
for (int i = 0; i < outs.size(); ++i) {
for (size_t i = 0; i < outs.size(); ++i) {
auto *out_ids = ids[i];
auto *out = outs[i];
......
......@@ -38,7 +38,7 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
} else {
trainer_id = *trainer_id_data;
}
PADDLE_ENFORCE_LT(trainer_id, in_list.size());
PADDLE_ENFORCE_LT((size_t)trainer_id, in_list.size());
out->mutable_data<T>(context.GetPlace());
out->ShareDataWith(*(in_list[trainer_id]));
}
......
......@@ -35,8 +35,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
math::SoftmaxFunctor<DeviceContext, T>()(
#ifdef ON_INFER
math::SoftmaxFunctor<DeviceContext, T, true>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
#else
math::SoftmaxFunctor<DeviceContext, T, false>()(
context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
#endif
}
};
......
......@@ -42,8 +42,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T>()(dev_ctx, logits,
softmax);
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
dev_ctx, logits, softmax);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
context.Attr<int>("ignore_index"));
......
......@@ -64,7 +64,7 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
out_ids.resize(outs.size());
// split id by their shard_num.
for (int i = 0; i < all_ids.size(); ++i) {
for (size_t i = 0; i < all_ids.size(); ++i) {
T id = all_ids[i];
size_t shard_id = static_cast<size_t>(id) % shard_num;
out_ids[shard_id].push_back(id);
......
......@@ -112,11 +112,11 @@ def __bootstrap__():
os.environ['OMP_NUM_THREADS'] = str(num_threads)
read_env_flags = [
'use_pinned_memory', 'check_nan_inf', 'benchmark',
'eager_delete_scope', 'use_mkldnn', 'use_ngraph',
'initial_cpu_memory_in_mb', 'init_allocated_mem', 'free_idle_memory',
'paddle_num_threads', 'dist_threadpool_size',
'eager_delete_tensor_gb', 'reader_queue_speed_test_mode'
'use_pinned_memory', 'check_nan_inf', 'benchmark', 'eager_delete_scope',
'use_mkldnn', 'use_ngraph', 'initial_cpu_memory_in_mb',
'init_allocated_mem', 'free_idle_memory', 'paddle_num_threads',
'dist_threadpool_size', 'eager_delete_tensor_gb',
'reader_queue_speed_test_mode'
]
if os.name != 'nt':
read_env_flags.append('warpctc_dir')
......
......@@ -26,6 +26,7 @@ from multiprocessing import Process
from functools import reduce
import numpy as np
import pickle
import unittest
import six
......@@ -166,7 +167,10 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
io.save_persistables(startup_exe, model_dir, trainer_prog)
var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor())
print(np.ravel(var).tolist())
if six.PY2:
print(pickle.dumps(np.ravel(var).tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
if __name__ == "__main__":
......
......@@ -65,14 +65,14 @@ class TestDistSaveLoadDense2x2(TestDistBase):
shutil.rmtree(model_dir)
local_np = np.array(eval(local_var[0]))
train0_np = np.array(eval(tr0_var[0]))
train1_np = np.array(eval(tr1_var[0]))
local_np = np.array(local_var)
train0_np = np.array(tr0_var)
train1_np = np.array(tr1_var)
self.assertAlmostEqual(local_np.all(), train0_np.all(), delta=delta)
self.assertAlmostEqual(local_np.all(), train1_np.all(), delta=delta)
self.assertAlmostEqual(train0_np.all(), train1_np.all(), delta=delta)
@unittest.skip(reason="CI fail")
def test_dist(self):
need_envs = {
"IS_DISTRIBUTED": '0',
......
requests==2.9.2
numpy>=1.12,<=1.14 # TODO: change to ">=1.12" once numpy fixes the bug in 1.15 and higher versions
protobuf==3.1
recordio>=0.1.0; sys_platform != 'win32'
recordio>=0.1.0
matplotlib==2.2.3 # TODO: let python3 paddlepaddle package use latest matplotlib
rarfile
scipy>=0.19.0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册