提交 458a2da3 编写于 作者: F fengjiayi

Merge branch 'fix_bugs' into dev_opdesc_in_python

......@@ -88,25 +88,30 @@ class Tensor {
* @brief Copy the content of external tensor to a new place.
*
* @param[in] src The external tensor.
* @param[in] ctx The device context contains place where to store.
* @param[in] dst_place The dst place.
* @param[in] ctx The device context contains device resources.
*
* @note CopyFrom supports CPU <-> GPU, GPU <-> GPU.
*/
// TODO(qijun): https://github.com/PaddlePaddle/Paddle/issues/4647
// Remove `CopyFrom` and `CopyFromVector` from Tensor interface
// and make them global functions
template <typename T>
inline void CopyFrom(const Tensor& src, const platform::Place& dst_place);
inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
const platform::DeviceContext& ctx);
/**
* @brief Copy the content of an external vector to a tensor.
*
* @param[in] src The external vector.
* @param[in] ctx The device context contains place where to store.
* @param[in] src The external tensor.
* @param[in] ctx The device context contains device resources.
*
* * @note CopyFromVector assumes that the tensor has been resized
* before invoking.
*/
template <typename T>
inline void CopyFromVector(const std::vector<T>& src,
const platform::Place& dst_place);
const platform::DeviceContext& ctx);
/**
* @brief Return the slice of the tensor.
......
......@@ -95,7 +95,8 @@ void TensorArray::Write(size_t index, const LoDTensor& value) {
values_[index].Resize(value.dims());
values_[index].mutable_data<value_type>(platform::CPUPlace());
values_[index].CopyFrom<value_type>(value, platform::CPUPlace());
values_[index].CopyFrom<value_type>(value, platform::CPUPlace(),
platform::CPUDeviceContext());
}
void TensorArray::WriteShared(size_t index, const LoDTensor& value) {
......@@ -151,7 +152,8 @@ LoDTensor TensorArray::Stack() const {
for (size_t idx = 0; idx < size(); idx++) {
result.Slice<value_type>(idx, idx + 1)
.CopyFrom<value_type>(Read(idx), platform::CPUPlace());
.CopyFrom<value_type>(Read(idx), platform::CPUPlace(),
platform::CPUDeviceContext());
}
return result;
}
......@@ -182,7 +184,8 @@ void TensorArray::Unstack(const LoDTensor& source, bool data_shared) const {
// copy
value.Resize(value_dims);
value.CopyFrom<value_type>(source.Slice<value_type>(elem, elem + 1),
platform::CPUPlace());
platform::CPUPlace(),
platform::CPUDeviceContext());
}
}
}
......@@ -236,7 +239,8 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
auto target = result.Slice<value_type>(i, i + 1);
auto source_ = source->Slice<value_type>(index, index + 1);
target.CopyFrom<value_type>(source_, platform::CPUPlace());
target.CopyFrom<value_type>(source_, platform::CPUPlace(),
platform::CPUDeviceContext());
}
return result;
......@@ -269,7 +273,8 @@ LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
if (index >= seq_meta.end) break;
auto source_ = source[batch_id].Slice<float>(seq_id, seq_id + 1);
auto target = result.Slice<float>(index, index + 1);
target.CopyFrom<float>(source_, platform::CPUPlace());
target.CopyFrom<float>(source_, platform::CPUPlace(),
platform::CPUDeviceContext());
}
}
......
......@@ -88,7 +88,8 @@ inline Tensor& Tensor::ShareDataWith(const Tensor& src) {
template <typename T>
inline void Tensor::CopyFrom(const Tensor& src,
const platform::Place& dst_place) {
const platform::Place& dst_place,
const platform::DeviceContext& ctx) {
src.check_memory_size<T>();
Resize(src.dims());
......@@ -106,26 +107,45 @@ inline void Tensor::CopyFrom(const Tensor& src,
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src_place) &&
platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
auto src_gpu_place = boost::get<platform::GPUPlace>(src_place);
auto dst_cpu_place = boost::get<platform::CPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size, 0);
auto src_cpu_place = boost::get<platform::CPUPlace>(src_place);
auto dst_gpu_place = boost::get<platform::GPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
auto src_gpu_place = boost::get<platform::GPUPlace>(src_place);
auto dst_gpu_place = boost::get<platform::GPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
PADDLE_ENFORCE(cudaStreamSynchronize(0),
"cudaStreamSynchronize failed in Tensor CopyFrom");
#endif
}
template <typename T>
inline void Tensor::CopyFromVector(const std::vector<T>& src,
const platform::Place& dst_place) {
const platform::DeviceContext& ctx) {
auto dst_place = ctx.GetPlace();
auto src_ptr = static_cast<const void*>(src.data());
platform::CPUPlace src_place;
auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
......@@ -137,12 +157,11 @@ inline void Tensor::CopyFromVector(const std::vector<T>& src,
}
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr, src_place,
src_ptr, size, 0);
memory::Copy(
boost::get<platform::GPUPlace>(dst_place), dst_ptr, src_place, src_ptr,
size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
PADDLE_ENFORCE(cudaStreamSynchronize(0),
"cudaStreamSynchronize failed in Tensor CopyFromVector");
#endif
}
......
......@@ -194,6 +194,7 @@ TEST(Tensor, CopyFrom) {
{
Tensor src_tensor;
Tensor dst_tensor;
CPUDeviceContext cpu_ctx((CPUPlace()));
int* src_ptr = src_tensor.mutable_data<int>(make_ddim({3, 3}), CPUPlace());
......@@ -201,7 +202,7 @@ TEST(Tensor, CopyFrom) {
memcpy(src_ptr, arr, 9 * sizeof(int));
auto cpu_place = new paddle::platform::CPUPlace();
dst_tensor.CopyFrom<int>(src_tensor, *cpu_place);
dst_tensor.CopyFrom<int>(src_tensor, *cpu_place, cpu_ctx);
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
......@@ -210,7 +211,7 @@ TEST(Tensor, CopyFrom) {
}
Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
dst_tensor.CopyFrom<int>(slice_tensor, *cpu_place);
dst_tensor.CopyFrom<int>(slice_tensor, *cpu_place, cpu_ctx);
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
......@@ -231,13 +232,15 @@ TEST(Tensor, CopyFrom) {
// CPU Tensor to GPU Tensor
auto gpu_place = new paddle::platform::GPUPlace(0);
gpu_tensor.CopyFrom<int>(src_tensor, *gpu_place);
CUDADeviceContext gpu_ctx(*gpu_place);
gpu_tensor.CopyFrom<int>(src_tensor, *gpu_place, gpu_ctx);
// GPU Tensor to CPU Tensor
auto cpu_place = new paddle::platform::CPUPlace();
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place, gpu_ctx);
// Compare Tensors
// Sync before Compare Tensors
gpu_ctx.Wait();
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
for (size_t i = 0; i < 9; ++i) {
......@@ -247,12 +250,13 @@ TEST(Tensor, CopyFrom) {
Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
// CPU Slice Tensor to GPU Tensor
gpu_tensor.CopyFrom<int>(slice_tensor, *gpu_place);
gpu_tensor.CopyFrom<int>(slice_tensor, *gpu_place, gpu_ctx);
// GPU Tensor to CPU Tensor
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place, gpu_ctx);
// Compare Slice Tensors
// Sync before Compare Slice Tensors
gpu_ctx.Wait();
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
......@@ -273,7 +277,8 @@ TEST(Tensor, CopyFromVector) {
// Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace();
cpu_tensor.CopyFromVector<int>(src_vec, *cpu_place);
CPUDeviceContext cpu_ctx(*cpu_place);
cpu_tensor.CopyFromVector<int>(src_vec, cpu_ctx);
// Compare Tensors
const int* cpu_ptr = cpu_tensor.data<int>();
......@@ -285,7 +290,7 @@ TEST(Tensor, CopyFromVector) {
src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2}));
cpu_tensor.CopyFromVector<int>(src_vec, *cpu_place);
cpu_tensor.CopyFromVector<int>(src_vec, cpu_ctx);
cpu_ptr = cpu_tensor.data<int>();
src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr);
......@@ -306,16 +311,19 @@ TEST(Tensor, CopyFromVector) {
// Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace();
cpu_tensor.CopyFromVector<int>(src_vec, *cpu_place);
CPUDeviceContext cpu_ctx(*cpu_place);
cpu_tensor.CopyFromVector<int>(src_vec, cpu_ctx);
// Copy to GPUTensor
gpu_tensor.Resize(make_ddim({3, 3}));
auto gpu_place = new paddle::platform::GPUPlace();
gpu_tensor.CopyFromVector<int>(src_vec, *gpu_place);
CUDADeviceContext gpu_ctx(*gpu_place);
gpu_tensor.CopyFromVector<int>(src_vec, gpu_ctx);
// Copy from GPU to CPU tensor for comparison
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place, gpu_ctx);
// Compare Tensors
// Sync before Compare Tensors
gpu_ctx.Wait();
const int* src_ptr = src_vec.data();
const int* cpu_ptr = cpu_tensor.data<int>();
const int* dst_ptr = dst_tensor.data<int>();
......@@ -329,11 +337,13 @@ TEST(Tensor, CopyFromVector) {
src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2}));
cpu_tensor.CopyFromVector<int>(src_vec, *cpu_place);
cpu_tensor.CopyFromVector<int>(src_vec, cpu_ctx);
gpu_tensor.Resize(make_ddim({2, 2}));
gpu_tensor.CopyFromVector<int>(src_vec, *gpu_place);
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
gpu_tensor.CopyFromVector<int>(src_vec, gpu_ctx);
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place, gpu_ctx);
// Sync before Compare Tensors
gpu_ctx.Wait();
src_ptr = src_vec.data();
cpu_ptr = cpu_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
......
......@@ -34,6 +34,7 @@ inline std::vector<T> RepeatedToVector(
template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec,
RepeatedField *repeated_field) {
repeated_field->Clear();
repeated_field->Reserve(vec.size());
for (const auto &elem : vec) {
*repeated_field->Add() = elem;
......@@ -44,6 +45,7 @@ inline void VectorToRepeated(const std::vector<T> &vec,
template <typename RepeatedField>
inline void VectorToRepeated(const std::vector<bool> &vec,
RepeatedField *repeated_field) {
repeated_field->Clear();
repeated_field->Reserve(vec.size());
for (auto elem : vec) {
*repeated_field->Add() = elem;
......
......@@ -321,6 +321,23 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
template <typename AttrType>
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ThresholdedReluOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Y", "Output of ThresholdedRelu operator");
AddComment(
"ThresholdedRelu activation operator, "
"thresholded_relu = x for x > threshold, "
"thresholded_relu = 0 otherwise.");
AddAttr<AttrType>("threshold", "The threshold location of activation")
.SetDefault(static_cast<AttrType>(1.0));
}
};
} // namespace operators
} // namespace paddle
......@@ -392,6 +409,10 @@ REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker<float>,
hard_shrink_grad, ops::ActivationOpGrad);
REGISTER_OP(thresholded_relu, ops::ActivationOp,
ops::ThresholdedReluOpMaker<float>, thresholded_relu_grad,
ops::ActivationOpGrad);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, \
......
......@@ -590,6 +590,32 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = (x > static_cast<T>(threshold)).template cast<T>() * x;
}
};
template <typename T>
struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (x > static_cast<T>(threshold)).template cast<T>();
}
};
} // namespace operators
} // namespace paddle
......@@ -615,4 +641,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
__macro(elu, ELUFunctor, ELUGradFunctor); \
__macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor)
__macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor); \
__macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
......@@ -34,7 +34,7 @@ class FeedKernel : public framework::OpKernel<T> {
// TODO(qijun):
// check tensors[col].dims() with attribute,
// except the first dimenson.
out->CopyFrom<T>(tensors[col], ctx.GetPlace());
out->CopyFrom<T>(tensors[col], ctx.GetPlace(), ctx.device_context());
}
};
......
......@@ -35,7 +35,8 @@ class FetchKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_GT(tensors->size(), static_cast<size_t>(col));
(*tensors)[col].Resize(input->dims());
(*tensors)[col].mutable_data<T>(platform::CPUPlace());
(*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
(*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace(),
ctx.device_context());
// TODO(qijun): need to handle LodTensor later
}
};
......
......@@ -49,10 +49,22 @@ void testIm2col() {
memcpy(input_ptr, arr, 6 * sizeof(float));
auto* place = new Place();
paddle::platform::DeviceContext* context;
if (paddle::platform::is_cpu_place(*place)) {
context =
new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
} else {
#ifdef PADDLE_WITH_CUDA
context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
#else
PADDLE_THROW("no GPU support");
#endif // PADDLE_ONLY_CPU
}
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom<float>(input_tmp, *place);
input.CopyFrom<float>(input_tmp, *place, *context);
}
output_cfo.mutable_data<float>(
{1, filter_size, filter_size, output_height, output_width}, *place);
......@@ -66,18 +78,6 @@ void testIm2col() {
paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf;
paddle::platform::DeviceContext* context;
if (paddle::platform::is_cpu_place(*place)) {
context =
new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
} else {
#ifdef PADDLE_WITH_CUDA
context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
#else
PADDLE_THROW("no GPU support");
#endif // PADDLE_ONLY_CPU
}
im2col(*context, input, output_cfo, stride, stride, padding, padding);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding);
......@@ -85,7 +85,8 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) {
out_cfo_ptr = output_cfo.data<float>();
} else {
output_tmp.CopyFrom<float>(output_cfo, paddle::platform::CPUPlace());
output_tmp.CopyFrom<float>(output_cfo, paddle::platform::CPUPlace(),
*context);
out_cfo_ptr = output_tmp.data<float>();
}
EXPECT_EQ(out_cfo_ptr[0], 0);
......@@ -101,7 +102,8 @@ void testIm2col() {
if (paddle::platform::is_cpu_place(*place)) {
out_ocf_ptr = output_ocf.data<float>();
} else {
output_tmp.CopyFrom<float>(output_ocf, paddle::platform::CPUPlace());
output_tmp.CopyFrom<float>(output_ocf, paddle::platform::CPUPlace(),
*context);
out_ocf_ptr = output_tmp.data<float>();
}
EXPECT_EQ(out_ocf_ptr[0], 0);
......
......@@ -17,17 +17,18 @@ TEST(math_function, notrans_mul_trans) {
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input1, *gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place, context);
input2_gpu.CopyFrom<float>(input1, *gpu_place, context);
out_gpu.mutable_data<float>({2, 2}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0);
out.CopyFrom<float>(out_gpu, *cpu_place);
out.CopyFrom<float>(out_gpu, *cpu_place, context);
float* out_ptr = out.data<float>();
context.Wait();
EXPECT_EQ(out_ptr[0], 5);
EXPECT_EQ(out_ptr[1], 14);
EXPECT_EQ(out_ptr[2], 14);
......@@ -50,17 +51,18 @@ TEST(math_function, trans_mul_notrans) {
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input1, *gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place, context);
input2_gpu.CopyFrom<float>(input1, *gpu_place, context);
out_gpu.mutable_data<float>({3, 3}, *gpu_place);
paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0);
out.CopyFrom<float>(out_gpu, *cpu_place);
out.CopyFrom<float>(out_gpu, *cpu_place, context);
float* out_ptr = out.data<float>();
context.Wait();
EXPECT_EQ(out_ptr[0], 9);
EXPECT_EQ(out_ptr[1], 12);
EXPECT_EQ(out_ptr[2], 15);
......@@ -98,9 +100,9 @@ TEST(math_function, gemm_notrans_cublas) {
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input2, *gpu_place);
input3_gpu.CopyFrom<float>(input3, *gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place, context);
input2_gpu.CopyFrom<float>(input2, *gpu_place, context);
input3_gpu.CopyFrom<float>(input3, *gpu_place, context);
float* a = input1_gpu.data<float>();
float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(*gpu_place);
......@@ -108,7 +110,7 @@ TEST(math_function, gemm_notrans_cublas) {
paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4);
input3.CopyFrom<float>(input3_gpu, *cpu_place);
input3.CopyFrom<float>(input3_gpu, *cpu_place, context);
// numpy code:
// a = np.arange(6).reshape(2, 3)
......@@ -116,6 +118,7 @@ TEST(math_function, gemm_notrans_cublas) {
// c = np.arange(8).reshape(2, 4)[:, 1:]
// out = np.arange(8).reshape(2, 4)
// out[:, 1:] = np.dot(a, b) + c
context.Wait();
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
EXPECT_EQ(input3_ptr[2], 28);
......@@ -152,9 +155,9 @@ TEST(math_function, gemm_trans_cublas) {
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place);
input2_gpu.CopyFrom<float>(input2, *gpu_place);
input3_gpu.CopyFrom<float>(input3, *gpu_place);
input1_gpu.CopyFrom<float>(input1, *gpu_place, context);
input2_gpu.CopyFrom<float>(input2, *gpu_place, context);
input3_gpu.CopyFrom<float>(input3, *gpu_place, context);
float* a = input1_gpu.data<float>();
float* b = input2_gpu.data<float>();
float* c = input3_gpu.mutable_data<float>(*gpu_place);
......@@ -162,7 +165,8 @@ TEST(math_function, gemm_trans_cublas) {
paddle::operators::math::gemm<paddle::platform::GPUPlace, float>(
context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4);
input3.CopyFrom<float>(input3_gpu, *cpu_place);
input3.CopyFrom<float>(input3_gpu, *cpu_place, context);
context.Wait();
EXPECT_EQ(input3_ptr[0], 0);
EXPECT_EQ(input3_ptr[1], 24);
......
......@@ -33,7 +33,8 @@ class MultiplexGPUKernel : public framework::OpKernel<T> {
auto cols = ins[0]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace(),
ctx.device_context());
auto* index = index_t_cpu.data<int32_t>();
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
......@@ -70,7 +71,8 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
auto cols = ins[0]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace(),
ctx.device_context());
auto* index = index_t_cpu.data<int32_t>();
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
......
......@@ -46,7 +46,7 @@ void RecurrentAlgorithm::Run(const Scope& scope,
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len, dev_ctx);
}
void RecurrentAlgorithm::CreateScopes(const Scope& scope,
......@@ -151,12 +151,12 @@ void RecurrentGradientAlgorithm::Run(
auto& step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
for (int step_id = seq_len - 1; step_id >= 0; --step_id) {
if (step_id != seq_len - 1) {
if (static_cast<size_t>(step_id) != seq_len - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
}
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len, dev_ctx);
LinkBootMemoryGradients(step_scopes[0]);
}
......
......@@ -33,7 +33,7 @@ class ReshapeKernel : public framework::OpKernel<T> {
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto out_dims = framework::make_ddim(shape_int64);
out->CopyFrom<T>(*in, ctx.GetPlace());
out->CopyFrom<T>(*in, ctx.GetPlace(), ctx.device_context());
out->Resize(out_dims);
}
};
......@@ -47,7 +47,7 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
d_x->mutable_data<T>(ctx.GetPlace());
auto in_dims = d_x->dims();
d_x->CopyFrom<T>(*d_out, ctx.GetPlace());
d_x->CopyFrom<T>(*d_out, ctx.GetPlace(), ctx.device_context());
d_x->Resize(in_dims);
}
};
......
......@@ -51,7 +51,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<std::string>& outlinks,
const size_t seq_len) {
const size_t seq_len, const platform::DeviceContext& ctx) {
for (size_t i = 0; i < outlinks.size(); i++) {
auto* output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.",
......@@ -72,7 +72,7 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
// TODO(luotao02) data type and platform::DeviceContext() should set
// correctly
(output->Slice<float>(j, j + 1))
.CopyFrom<float>(*step_output, platform::CPUPlace());
.CopyFrom<float>(*step_output, platform::CPUPlace(), ctx);
}
}
}
......
......@@ -71,7 +71,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
*/
void ConcatOutputs(const std::vector<Scope*>& step_scopes,
const std::vector<std::string>& outlinks,
const size_t seq_len);
const size_t seq_len, const platform::DeviceContext& ctx);
void LinkMemories(const std::vector<Scope*>& step_scopes,
const std::vector<MemoryAttr>& memories, const size_t step_id,
......
......@@ -54,7 +54,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(
ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
"uniform_random's min must less then max");
auto dims = Attr<std::vector<int>>("dims");
auto& dims = ctx->Attrs().Get<std::vector<int>>("dims");
std::vector<int64_t> temp;
temp.reserve(dims.size());
for (auto dim : dims) {
......
......@@ -57,7 +57,18 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
}
framework::Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.place())) {
dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
#ifdef PADDLE_WITH_CUDA
auto *src_ptr = static_cast<const void *>(tensor.data<CUR_TYPE>());
auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>(
tensor.dims(), platform::CPUPlace()));
// TODO(qijun): Here we use default CUDA stream to set GPU Tensor to
// a Python numpy array. It's better to manage CDUA stream unifiedly.
paddle::platform::GpuMemcpySync(dst_ptr, src_ptr,
sizeof(CUR_TYPE) * tensor.numel(),
cudaMemcpyDeviceToHost);
#else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#endif
} else if (paddle::platform::is_cpu_place(tensor.place())) {
dst_tensor = tensor;
}
......@@ -120,6 +131,8 @@ void PyCUDATensorSetFromArray(
self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(place);
// TODO(qijun): Here we use default CUDA stream to set a Python numpy
// array to a GPU Tensor. It's better to manage CDUA stream unifiedly.
paddle::platform::GpuMemcpySync(dst, array.data(), sizeof(T) * array.size(),
cudaMemcpyHostToDevice);
}
......
......@@ -363,5 +363,26 @@ class TestSoftsign(OpTest):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestThresholdedRelu(OpTest):
def setUp(self):
self.op_type = "thresholded_relu"
threshold = 0.25
self.relative_error = 0.005
X = np.random.uniform(-1, 1, [11, 17]).astype("float32")
# Same reason as TestAbs
X[np.abs(X - threshold) < self.relative_error] = threshold + 0.2
self.inputs = {'X': X}
self.attrs = {'threshold': threshold}
self.outputs = {'Y': (X > threshold) * X}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=self.relative_error)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册