提交 6332bd1e 编写于 作者: L Luo Tao

Merge branch 'develop' into infer_mkl

......@@ -6,32 +6,32 @@ PaddlePaddle provides the users the ability to flexibly set various command line
.. toctree::
:maxdepth: 1
cmd_parameter/index_cn.rst
cmd_parameter/index_en.rst
PaddlePaddle supports distributed training tasks on fabric clusters, MPI clusters, and Kubernetes clusters. For detailed configuration and usage instructions, refer to:
.. toctree::
:maxdepth: 1
cluster/index_cn.rst
cluster/index_en.rst
PaddlePaddle provides a C-API for inference. We provide the following guidelines for using the C-API:
.. toctree::
:maxdepth: 1
capi/index_cn.rst
capi/index_en.rst
PaddlePaddle supports a variety of flexible and efficient recurrent neural networks. For details, please refer to:
.. toctree::
:maxdepth: 1
rnn/index_cn.rst
rnn/index_en.rst
How to use the built-in timing tool, nvprof, or nvvp to run performance analysis and tuning, please refer to:
.. toctree::
:maxdepth: 1
optimization/gpu_profiling_cn.rst
optimization/gpu_profiling_en.rst
......@@ -117,10 +117,10 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
if (holder_ != nullptr) {
holder_->set_type(type);
}
PADDLE_ENFORCE_GT(
numel(), 0,
"When calling this method, the Tensor's numel must be larger than zero. "
"Please check Tensor::Resize has been called first.");
PADDLE_ENFORCE_GE(numel(), 0,
"When calling this method, the Tensor's numel must be "
"equal or larger than zero. "
"Please check Tensor::Resize has been called first.");
int64_t size = numel() * SizeOfType(type);
/* some versions of boost::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) ||
......
......@@ -59,7 +59,7 @@ TEST(BuddyAllocator, CPUMultAlloc) {
EXPECT_EQ(total_size, 0UL);
for (auto size :
{128, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304}) {
{0, 128, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304}) {
ps[paddle::memory::Alloc(cpu, size)] = size;
// Buddy Allocator doesn't manage too large memory chunk
......@@ -117,7 +117,7 @@ TEST(BuddyAllocator, GPUMultAlloc) {
EXPECT_EQ(total_size, 0UL);
for (auto size :
{128, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304}) {
{0, 128, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304}) {
ps[paddle::memory::Alloc(gpu, size)] = size;
// Buddy Allocator doesn't manage too large memory chunk
......
......@@ -153,7 +153,12 @@ function(op_library TARGET)
# pybind USE_OP_DEVICE_KERNEL for MKLDNN
if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
# Append first implemented MKLDNN activation operator
if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif()
endif()
# pybind USE_OP
......@@ -182,9 +187,13 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(recv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(listen_and_serv_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(listen_and_serv_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS send_op listen_and_serv_op sum_op executor)
else()
set(DEPS_OPS ${DEPS_OPS} send_op recv_op listen_and_serv_op)
set(DEPS_OPS ${DEPS_OPS} send_op recv_op listen_and_serv_op send_vars_op send_barrier_op)
endif()
op_library(cond_op DEPS framework_proto tensor net_op)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "mkldnn.hpp"
#include "mkldnn_activation_op.h"
#include "paddle/fluid/operators/activation_op.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
namespace {
template <typename T, typename ExecContext>
void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
const T alpha = 0, const T beta = 0) {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
// get buffers
const auto *src = ctx.template Input<Tensor>("X");
const auto *src_data = src->template data<T>();
auto *dst = ctx.template Output<Tensor>("Out");
const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
// get memory dim
PADDLE_ENFORCE(src->dims().size() == 4,
"Input dim must be with 4, i.e. NCHW");
std::vector<int> src_tz = framework::vectorize2int(src->dims());
// create memory description
// TODO(kbinias-intel): support more formats
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw);
// create memory primitives
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src_data);
auto dst_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)dst_data);
auto forward_desc = mkldnn::eltwise_forward::desc(
mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
// save prim desc into global device context to be referred in backward path
const std::string key = ctx.op().Output("Out");
const std::string key_eltwise_pd = key + "@eltwise_pd";
auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
forward_desc, mkldnn_engine);
dev_ctx.SetBlob(key_eltwise_pd, forward_pd);
auto eltwise = mkldnn::eltwise_forward(*forward_pd, src_memory, dst_memory);
// push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline = {eltwise};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
template <typename T, typename ExecContext>
void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
const T alpha = 0, const T beta = 0) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine();
// get buffers
const auto *x = ctx.template Input<Tensor>("X");
const auto *src = x->template data<T>();
auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
const auto *diff_dst = dout->template data<T>();
auto *dx =
ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
const T *diff_src = dx->template mutable_data<T>(ctx.GetPlace());
// get memory dim
std::vector<int> src_tz = framework::vectorize2int(x->dims());
// create memory description
auto data_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw);
// create memory primitives
auto src_memory = mkldnn::memory({data_md, mkldnn_engine}, (void *)src);
auto diff_src_memory =
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_src);
auto diff_dst_memory =
mkldnn::memory({data_md, mkldnn_engine}, (void *)diff_dst);
auto backward_desc =
mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);
// retrieve eltwise primitive desc from device context
const std::string key = ctx.op().Input("Out");
const std::string key_eltwise_pd = key + "@eltwise_pd";
const std::shared_ptr<void> forward_pd = dev_ctx.GetBlob(key_eltwise_pd);
PADDLE_ENFORCE(forward_pd != nullptr,
"Fail to find eltwise_pd in device context");
auto *p_forward_pd =
static_cast<mkldnn::eltwise_forward::primitive_desc *>(forward_pd.get());
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
backward_desc, mkldnn_engine, *p_forward_pd);
auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, src_memory,
diff_dst_memory, diff_src_memory);
// push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
} // anonymous namespace
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
template <typename ExecContext>
void operator()(const ExecContext &ctx) const {
eltwise_forward<T>(ctx, algorithm);
}
};
template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
template <typename ExecContext>
void operator()(const ExecContext &ctx) const {
eltwise_grad<T>(ctx, algorithm);
}
};
template <typename T>
using ReluMkldnnFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_relu>;
template <typename T>
using TanhMkldnnFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_tanh>;
template <typename T>
using SqrtMkldnnFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_sqrt>;
template <typename T>
using AbsMkldnnFunctor =
MKLDNNActivationFunc<T, mkldnn::algorithm::eltwise_abs>;
template <typename T>
using ReluMkldnnGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_relu>;
template <typename T>
using TanhMkldnnGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_tanh>;
template <typename T>
using SqrtMkldnnGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_sqrt>;
template <typename T>
using AbsMkldnnGradFunctor =
MKLDNNActivationGradFunc<T, mkldnn::algorithm::eltwise_abs>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationKernel<ops::functor<float>>); \
REGISTER_OP_KERNEL( \
act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
__macro(relu, ReluMkldnnFunctor, ReluMkldnnGradFunctor); \
__macro(tanh, TanhMkldnnFunctor, TanhMkldnnGradFunctor); \
__macro(sqrt, SqrtMkldnnFunctor, SqrtMkldnnGradFunctor); \
__macro(abs, AbsMkldnnFunctor, AbsMkldnnGradFunctor);
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h"
namespace paddle {
namespace operators {
......@@ -87,6 +88,9 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu operator");
AddOutput("Out", "Output of Relu operator");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Relu Activation Operator.
......@@ -140,6 +144,9 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Tanh operator");
AddOutput("Out", "Output of Tanh operator");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Tanh Activation Operator.
......@@ -193,6 +200,9 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sqrt operator");
AddOutput("Out", "Output of Sqrt operator");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Sqrt Activation Operator.
......@@ -208,6 +218,9 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Abs operator");
AddOutput("Out", "Output of Abs operator");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddComment(R"DOC(
Abs Activation Operator.
......@@ -524,11 +537,11 @@ REGISTER_OP(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker,
REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
ops::ActivationOpGrad);
REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
ops::ActivationOpGrad);
REGISTER_OP(relu, ops::ActivationWithMKLDNNOp, ops::ReluOpMaker, relu_grad,
ops::ActivationWithMKLDNNOpGrad);
REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
ops::ActivationOpGrad);
REGISTER_OP(tanh, ops::ActivationWithMKLDNNOp, ops::TanhOpMaker, tanh_grad,
ops::ActivationWithMKLDNNOpGrad);
REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
tanh_shrink_grad, ops::ActivationOpGrad);
......@@ -536,11 +549,11 @@ REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker,
softshrink_grad, ops::ActivationOpGrad);
REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationOpGrad);
REGISTER_OP(sqrt, ops::ActivationWithMKLDNNOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationWithMKLDNNOpGrad);
REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad,
ops::ActivationOpGrad);
REGISTER_OP(abs, ops::ActivationWithMKLDNNOp, ops::AbsOpMaker, abs_grad,
ops::ActivationWithMKLDNNOpGrad);
REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad,
ops::ActivationOpGrad);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -17,6 +17,10 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
......
......@@ -49,9 +49,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
s->Prepare(var_h, time_out);
s->response_call_back_ = NULL;
auto call = std::move(s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req,
&cq_));
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s);
});
......@@ -107,8 +106,8 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
auto call = std::move(s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_));
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, (void*)s);
});
......
......@@ -22,9 +22,10 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/simple_block_queue.h"
namespace paddle {
namespace operators {
......
......@@ -199,9 +199,9 @@ TEST(LodTensor, Run) {
RunTestLodTensor(place);
RunTestLodTensor(place, 1);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace place;
RunTestLodTensor(place);
RunTestLodTensor(place, 1);
platform::CUDAPlace gpu(0);
RunTestLodTensor(gpu);
RunTestLodTensor(gpu, 1);
#endif
}
......@@ -210,7 +210,7 @@ TEST(SelectedRows, Run) {
RunSerdeTestSelectedRows(place);
#ifdef PADDLE_WITH_CUDA
platform::CUDAPlace place;
RunSerdeTestSelectedRows(place);
platform::CUDAPlace gpu;
RunSerdeTestSelectedRows(gpu);
#endif
}
......@@ -33,6 +33,7 @@ __global__ void RandomGenerator(const size_t n, const int seed,
int idx = blockDim.x * blockIdx.x + threadIdx.x;
for (; idx < n; idx += blockDim.x * gridDim.x) {
rng.discard(idx);
if (dist(rng) < dropout_prob) {
mask_data[idx] = static_cast<T>(0);
} else {
......
......@@ -93,12 +93,6 @@ class ListenAndServOp : public framework::OperatorBase {
"server program should have at least 2 blocks");
framework::Executor executor(dev_place);
std::vector<framework::ExecutorPrepareContext *> blk_ctx_list;
blk_ctx_list.push_back(nullptr); // block0 is not used.
for (int blkid = 1; blkid < num_blocks; ++blkid) {
auto *exe_ctx = executor.Prepare(*program, blkid);
blk_ctx_list.push_back(exe_ctx);
}
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
......@@ -149,12 +143,11 @@ class ListenAndServOp : public framework::OperatorBase {
std::vector<std::future<void>> fs;
// block0 contains only listen_and_serv op, start run from block1.
for (int blkid = 1; blkid < num_blocks - 1; ++blkid) {
fs.push_back(framework::Async(
[&executor, &program, &recv_scope, &blk_ctx_list, blkid]() {
fs.push_back(
framework::Async([&executor, &program, &recv_scope, blkid]() {
int run_block = blkid; // thread local
try {
executor.RunPreparedContext(blk_ctx_list[run_block],
&recv_scope, false, false);
executor.Run(*program, &recv_scope, run_block, false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
......@@ -164,8 +157,7 @@ class ListenAndServOp : public framework::OperatorBase {
// Run global block at final step, or block1 if there are only 2 blocks
if (num_blocks >= 2) {
try {
executor.RunPreparedContext(blk_ctx_list[num_blocks - 1], &recv_scope,
false, false);
executor.Run(*program, &recv_scope, num_blocks - 1, false, false);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
......@@ -185,9 +177,9 @@ class ListenAndServOp : public framework::OperatorBase {
sparse_vars.clear();
} // while(true)
for (int i = 0; i < num_blocks; ++i) {
delete blk_ctx_list[i];
}
// for (int i = 0; i < num_blocks; ++i) {
// delete blk_ctx_list[i];
// }
}
protected:
......
......@@ -20,7 +20,7 @@ namespace math {
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
* each dimension must be the same, except the axis dimension.
*/
template <typename T>
class ConcatFunctor<platform::CPUDeviceContext, T> {
......@@ -63,7 +63,7 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
* each dimension must be the same, except the axis dimension.
*/
template <typename T>
class ConcatGradFunctor<platform::CPUDeviceContext, T> {
......
......@@ -66,68 +66,66 @@ __global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
}
template <typename T>
__global__ void KernelConcat(T** inputs, const int input_col,
const int output_rows, const int output_cols,
T* output) {
__global__ void KernelConcat(T** inputs_data, const int fixed_in_col,
const int out_rows, const int out_cols,
T* output_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
double inv_input_col = 1.0 / input_col;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col;
T* input_ptr = inputs[split];
for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * 1.0 / fixed_in_col;
int in_offset = tid_x - split * fixed_in_col;
T* input_ptr = inputs_data[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * input_col + in_offset];
for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
output_data[tid_y * out_cols + tid_x] =
input_ptr[tid_y * fixed_in_col + in_offset];
}
}
}
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int* output_cols,
int col_size, T** outputs) {
__global__ void KernelConcatGrad(const T* input_data, const int in_row,
const int in_col, const int* out_cols,
int out_cols_size, T** outputs_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
int curr_offset = output_cols[segment];
int segment = upper_bound<int>(out_cols, out_cols_size, tid_x) - 1;
int curr_offset = out_cols[segment];
int curr_segment = segment;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset;
while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
while ((curr_col_offset = out_cols[curr_segment + 1]) <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
}
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
T* output_ptr = outputs[curr_segment];
T* output_ptr = outputs_data[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * segment_width + local_col] =
input[tid_y * input_col + tid_x];
input_data[tid_y * in_col + tid_x];
}
}
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int output_cols,
T** outputs) {
__global__ void KernelConcatGrad(const T* input_data, const int in_row,
const int in_col, const int fixed_out_col,
T** outputs_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
double inv_input_col = 1.0 / input_col;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col;
T* output_ptr = outputs[split];
for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
int split = tid_x / fixed_out_col;
int in_offset = tid_x - split * fixed_out_col;
T* output_ptr = outputs_data[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * output_cols + in_offset] =
input[tid_y * input_col + tid_x];
for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * fixed_out_col + in_offset] =
input_data[tid_y * in_col + tid_x];
}
}
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
* each dimension must be the same, except the axis dimension.
*/
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
......@@ -136,41 +134,40 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int num = input.size();
int rows = 1;
int in_num = input.size();
int in_row = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
in_row *= dim_0[i];
}
int cols = input[0].numel() / rows;
int out_rows = rows, out_cols = 0;
int in_col = input[0].numel() / in_row;
int out_row = in_row, out_col = 0;
framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
framework::Vector<int> inputs_cols(num + 1);
inputs_cols[0] = 0;
framework::Vector<int16_t> inputs_data(in_num * sizeof(T*) / 2);
framework::Vector<int> inputs_col(in_num + 1);
T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
inputs_col[0] = 0;
bool sameShape = true;
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
for (int i = 0; i < in_num; ++i) {
int t_cols = input[i].numel() / in_row;
if (sameShape) {
if (t_cols != cols) sameShape = false;
if (t_cols != in_col) sameShape = false;
}
out_cols += t_cols;
inputs_cols[i + 1] = out_cols;
out_col += t_cols;
inputs_col[i + 1] = out_col;
inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
}
T** ins_gpu =
T** dev_ins_data =
reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());
// computation
// set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (out_cols < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((out_cols + 31) >> 5) << 5;
if (out_col < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((out_col + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
......@@ -179,25 +176,26 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
std::min((out_col + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (sameShape) {
KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
ins_gpu, cols, out_rows, out_cols, output->data<T>());
dev_ins_data, in_col, out_row, out_col, output->data<T>());
} else {
const int* dev_ins_col_data = inputs_col.CUDAData(context.GetPlace());
KernelConcat<<<grid_size, block_size, 0, context.stream()>>>(
ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows,
out_cols, output->data<T>());
dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()),
out_row, out_col, output->data<T>());
}
}
};
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
* each dimension must be the same, except the axis dimension.
*/
template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
......@@ -206,41 +204,40 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>& outputs) {
// TODO(zcd): Add input data validity checking
int num = outputs.size();
int input_row = 1;
int o_num = outputs.size();
int out_row = 1;
auto dim_0 = outputs[0].dims();
for (int i = 0; i < axis; ++i) {
input_row *= dim_0[i];
out_row *= dim_0[i];
}
int output_col_0 = outputs[0].numel() / input_row;
int input_col = 0;
int out_col = outputs[0].numel() / out_row;
int in_col = 0, in_row = out_row;
bool sameShape = true;
framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
framework::Vector<int> outputs_cols(num + 1);
outputs_cols[0] = 0;
framework::Vector<int16_t> outputs_data(o_num * sizeof(T*) / 2);
framework::Vector<int> outputs_cols(o_num + 1);
T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
for (int i = 0; i < num; ++i) {
int t_col = outputs[i].numel() / input_row;
outputs_cols[0] = 0;
for (int i = 0; i < o_num; ++i) {
int t_col = outputs[i].numel() / out_row;
if (sameShape) {
if (t_col != output_col_0) sameShape = false;
if (t_col != out_col) sameShape = false;
}
input_col += t_col;
outputs_cols[i + 1] = input_col;
in_col += t_col;
outputs_cols[i + 1] = in_col;
outputs_ptr[i] = outputs[i].data<T>();
}
T** outs_gpu =
T** dev_out_gpu_data =
reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());
// computation
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (input_col < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((input_col + 31) >> 5) << 5;
if (in_col < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((in_col + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
......@@ -249,18 +246,19 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((input_col + block_cols - 1) / block_cols, max_blocks);
std::min((in_col + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1));
std::min(max_blocks / grid_cols, std::max(out_row / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (sameShape) {
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), input_row, input_col, output_col_0, outs_gpu);
input.data<T>(), in_row, in_col, out_col, dev_out_gpu_data);
} else {
const int* dev_outs_col_data = outputs_cols.CUDAData(context.GetPlace());
KernelConcatGrad<<<grid_size, block_size, 0, context.stream()>>>(
input.data<T>(), input_row, input_col, outs_col_gpu,
static_cast<int>(outputs_cols.size()), outs_gpu);
input.data<T>(), in_row, in_col, dev_outs_col_data,
static_cast<int>(outputs_cols.size()), dev_out_gpu_data);
}
}
};
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace operators {
template <typename Functor>
class MKLDNNActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(context.Input<framework::Tensor>("X") != nullptr,
"Cannot get input tensor X, variable name = %s",
context.op().Input("X"));
PADDLE_ENFORCE(context.Output<framework::Tensor>("Out") != nullptr,
"Cannot find output tensor Out, variable name = %s",
context.op().Output("Out"));
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(context);
}
};
template <typename Functor>
class MKLDNNActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(context);
}
};
namespace {
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper) {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace(), layout, library);
}
} // anonymous namespace
class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this);
}
};
class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include <future>
#include "paddle/fluid/operators/detail/grpc_client.h"
namespace paddle {
namespace operators {
class SendBarrierOp : public framework::OperatorBase {
public:
SendBarrierOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
// need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
}
};
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SendBarrierOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC(
SendBarrier operator
This operator will send a send barrier signal to list_and_serv op, so that
the Parameter Server would knew all variables have been sent.
)DOC");
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"});
}
};
class SendBarrierOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendBarrierOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp,
paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker,
ops::SendBarrierOpVarTypeInference,
ops::SendBarrierOpShapeInference);
......@@ -21,6 +21,7 @@ limitations under the License. */
#include <future>
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -59,6 +60,9 @@ class SendOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include <future>
#include "paddle/fluid/operators/detail/grpc_client.h"
namespace paddle {
namespace operators {
static bool NeedSend(const framework::Scope& scope,
const std::string& varname) {
auto* var = scope.FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
varname);
if (var->IsType<framework::LoDTensor>()) {
return var->Get<framework::LoDTensor>().IsInitialized();
} else if (var->IsType<framework::SelectedRows>()) {
return var->Get<framework::SelectedRows>().rows().size() > 0UL;
} else {
PADDLE_THROW(
"Variable type in send side should be in "
"[LodTensor, SelectedRows]");
}
return false;
}
class SendVarsOp : public framework::OperatorBase {
public:
SendVarsOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto ins = Inputs("X");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
int sync_send = Attr<int>("sync_sent");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
auto client_var_name = Output("RPCClient");
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
// TODO(Yancey1989): we need to use an IO threadpool which has
// a larger number of threads than the computing threadpool.
rpc_client->AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
if (sync_send) {
rpc_client->Wait();
}
}
};
class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SendVarsOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddComment(R"DOC(
Send operator
This operator will send variables to listen_and_serve op at the parameter server.
)DOC");
AddAttr<int>("ync_send",
"(int, default 0)"
"sync send or async send.")
.SetDefault(0);
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({"127.0.0.1:6164"});
}
};
class SendVarsOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendVarsOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(send_vars, ops::SendVarsOp,
paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker,
ops::SendVarsOpVarTypeInference,
ops::SendVarsOpShapeInference);
......@@ -49,7 +49,7 @@ nv_test(device_context_test SRCS device_context_test.cu DEPS device_context gpu_
nv_test(cudnn_helper_test SRCS cudnn_helper_test.cc DEPS dynload_cuda)
nv_test(transform_test SRCS transform_test.cu DEPS paddle_memory place device_context)
cc_library(device_tracer SRCS device_tracer.cc DEPS profiler_proto ${GPU_CTX_DEPS})
cc_library(device_tracer SRCS device_tracer.cc DEPS boost profiler_proto ${GPU_CTX_DEPS})
cc_library(profiler SRCS profiler.cc DEPS device_context device_tracer)
cc_test(profiler_test SRCS profiler_test.cc DEPS profiler)
......
......@@ -403,6 +403,8 @@ class LayerHelper(object):
if 'use_mkldnn' in self.kwargs:
act['use_mkldnn'] = self.kwargs.get('use_mkldnn')
act_type = act.pop('type')
if 'use_mkldnn' in self.kwargs:
act['use_mkldnn'] = self.kwargs.get('use_mkldnn')
self.append_op(
type=act_type,
inputs={"X": [input_var]},
......
......@@ -506,5 +506,54 @@ class TestSwish(OpTest):
self.check_grad(['X'], 'Out', max_relative_error=0.008)
#--------------------test MKLDNN--------------------
class TestMKLDNNRelu(TestRelu):
def setUp(self):
super(TestMKLDNNRelu, self).setUp()
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
out = np.maximum(x, 0)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True}
class TestMKLDNNTanh(TestTanh):
def setUp(self):
super(TestMKLDNNTanh, self).setUp()
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
}
self.outputs = {'Out': np.tanh(self.inputs['X'])}
self.attrs = {"use_mkldnn": True}
class TestMKLDNNSqrt(TestSqrt):
def setUp(self):
super(TestMKLDNNSqrt, self).setUp()
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
}
self.outputs = {'Out': np.sqrt(self.inputs['X'])}
self.attrs = {"use_mkldnn": True}
class TestMKLDNNAbs(TestAbs):
def setUp(self):
super(TestMKLDNNAbs, self).setUp()
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x}
self.outputs = {'Out': np.abs(self.inputs['X'])}
self.attrs = {"use_mkldnn": True}
if __name__ == "__main__":
unittest.main()
......@@ -20,19 +20,35 @@ from op_test import OpTest
class TestConcatOp(OpTest):
def setUp(self):
self.op_type = "concat"
x0 = np.random.random((2, 1, 4, 5)).astype('float32')
x1 = np.random.random((2, 2, 4, 5)).astype('float32')
x2 = np.random.random((2, 3, 4, 5)).astype('float32')
axis = 1
self.inputs = {'X': [('x0', x0), ('x1', x1), ('x2', x2)]}
self.attrs = {'axis': axis}
self.outputs = {'Out': np.concatenate((x0, x1, x2), axis=axis)}
self.init_test_data()
self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
self.attrs = {'axis': self.axis}
self.outputs = {
'Out': np.concatenate(
(self.x0, self.x1, self.x2), axis=self.axis)
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['x0'], 'Out')
self.check_grad(['x1'], 'Out')
self.check_grad(['x2'], 'Out')
def init_test_data(self):
self.x0 = np.random.random((2, 1, 4, 5)).astype('float32')
self.x1 = np.random.random((2, 2, 4, 5)).astype('float32')
self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
self.axis = 1
class TestConcatOp2(OpTest):
def init_test_data(self):
self.x0 = np.random.random((2, 3, 4, 5)).astype('float32')
self.x1 = np.random.random((2, 3, 4, 5)).astype('float32')
self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
self.axis = 1
if __name__ == '__main__':
......
......@@ -126,7 +126,6 @@ class TestTensor(unittest.TestCase):
def test_lod_tensor_gpu_init(self):
if not core.is_compiled_with_cuda():
return
scope = core.Scope()
place = core.CUDAPlace(0)
lod_py = [[0, 2, 5], [0, 2, 4, 5]]
lod_tensor = core.LoDTensor()
......@@ -144,6 +143,25 @@ class TestTensor(unittest.TestCase):
self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1])
self.assertListEqual(lod_py, lod_tensor.lod())
def test_empty_tensor(self):
place = core.CPUPlace()
scope = core.Scope()
var = scope.var("test_tensor")
tensor = var.get_tensor()
tensor.set_dims([0, 1])
tensor.alloc_float(place)
tensor_array = numpy.array(tensor)
self.assertEqual((0, 1), tensor_array.shape)
if core.is_compiled_with_cuda():
gpu_place = core.CUDAPlace(0)
tensor.alloc_float(gpu_place)
tensor_array = numpy.array(tensor)
self.assertEqual((0, 1), tensor_array.shape)
if __name__ == '__main__':
unittest.main()
......@@ -16,7 +16,7 @@ Creator package contains some simple reader creator, which could
be used in user program.
"""
__all__ = ['np_array', 'text_file', "cloud_reader"]
__all__ = ['np_array', 'text_file', 'recordio', 'cloud_reader']
def np_array(x):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册