Unverified commit 31392627, authored by ronnywang, committed by GitHub

[ROCM] bugfix for unittest (#32392)

* fix test_unpool_op

* fix test_inplace_addto_strategy

* fix test_conv2d_fusion_op

* fix test_imperative_lod_tensor_to_selected_rows, test_imperative_selected_rows_to_lod_tensor

* fix test_dot_op

* fix test_correlation_op

* fix tracer

* fix test_memcpy_op
Parent: efdb0a7d
@@ -180,7 +180,6 @@ function(op_library TARGET)
     list(REMOVE_ITEM miopen_cu_cc_srcs "affine_grid_cudnn_op.cu.cc")
     list(REMOVE_ITEM miopen_cu_cc_srcs "grid_sampler_cudnn_op.cu.cc")
     list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
-    list(REMOVE_ITEM hip_srcs "correlation_op.cu")
     list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
     list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
     hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
......
@@ -699,24 +699,51 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     // ------------------- cudnn conv backward data ---------------------
     ScalingParamType<T> alpha = 1.0f;
-#ifdef PADDLE_WITH_HIP
-    // MIOPEN ONLY support beta to be 0.0f
-    ScalingParamType<T> beta = 0.0f;
-#else
     ScalingParamType<T> beta = ctx.Attr<bool>("use_addto") ? 1.0f : 0.0f;
-#endif
     VLOG(4) << "Conv_grad: use_addto = " << ctx.Attr<bool>("use_addto");
     if (input_grad) {
       // When beta is 0, it is unnecessary to reset input_grad.
       // When beta is 1, the output cannot be reset since addt strategy used.
 #ifdef PADDLE_WITH_HIP
-      workspace_handle.RunFunc(
-          [&](void* cudnn_workspace_ptr) {
-            PADDLE_ENFORCE_CUDA_SUCCESS(
-                platform::dynload::miopenConvolutionBackwardData(
-                    handle, &alpha, args1.odesc.desc(), output_grad_data,
-                    args1.wdesc.desc(), filter_data, args1.cdesc.desc(),
-                    data_algo, &beta, args1.idesc.desc(),
-                    transformed_input_grad_data, cudnn_workspace_ptr,
-                    workspace_size));
-          },
-          workspace_size);
+      if (ctx.Attr<bool>("use_addto")) {
+        Tensor temp_tensor(transformed_input_grad.type());
+        temp_tensor.Resize(transformed_input_grad.dims());
+        T* temp_tensor_data = temp_tensor.mutable_data<T>(ctx.GetPlace());
+        workspace_handle.RunFunc(
+            [&](void* cudnn_workspace_ptr) {
+              PADDLE_ENFORCE_CUDA_SUCCESS(
+                  platform::dynload::miopenConvolutionBackwardData(
+                      handle, &alpha, args1.odesc.desc(), output_grad_data,
+                      args1.wdesc.desc(), filter_data, args1.cdesc.desc(),
+                      data_algo, &beta, args1.idesc.desc(), temp_tensor_data,
+                      cudnn_workspace_ptr, workspace_size));
+            },
+            workspace_size);
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenOpTensor(
+            handle, miopenTensorOpAdd, &alpha, args1.idesc.desc(),
+            transformed_input_grad_data, &alpha, args1.idesc.desc(),
+            temp_tensor_data, &beta, args1.idesc.desc(),
+            transformed_input_grad_data));
+      } else {
+        workspace_handle.RunFunc(
+            [&](void* cudnn_workspace_ptr) {
+              PADDLE_ENFORCE_CUDA_SUCCESS(
+                  platform::dynload::miopenConvolutionBackwardData(
+                      handle, &alpha, args1.odesc.desc(), output_grad_data,
+                      args1.wdesc.desc(), filter_data, args1.cdesc.desc(),
+                      data_algo, &beta, args1.idesc.desc(),
+                      transformed_input_grad_data, cudnn_workspace_ptr,
+                      workspace_size));
+            },
+            workspace_size);
+      }
 #else
     for (int i = 0; i < groups; i++) {
       workspace_handle.RunFunc(
......
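Note: the hunk above works around a MIOpen limitation that the removed comment states directly: miopenConvolutionBackwardData only honors beta == 0, while the inplace-addto strategy relies on cuDNN's beta == 1 behavior of accumulating into the existing gradient. The new HIP branch therefore writes the result into a scratch tensor and folds it back with miopenOpTensor(miopenTensorOpAdd). A minimal plain-C++ sketch of that identity (backend_op is a stand-in for the convolution, not MIOpen code):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Stand-in for the backward-data convolution (not MIOpen code).
    std::vector<float> backend_op(const std::vector<float>& x) {
      std::vector<float> y(x.size());
      for (std::size_t i = 0; i < x.size(); ++i) y[i] = 2.0f * x[i];
      return y;
    }

    int main() {
      std::vector<float> grad = {1.0f, 1.0f};  // existing gradient ("addto" input)
      std::vector<float> x = {3.0f, 4.0f};
      // beta == 0 backend: write into scratch, then fold back, like the
      // temp_tensor + miopenOpTensor(Add) pair above.
      std::vector<float> temp = backend_op(x);
      for (std::size_t i = 0; i < grad.size(); ++i) grad[i] += temp[i];
      assert(grad[0] == 7.0f && grad[1] == 9.0f);  // identical to a beta == 1 call
      return 0;
    }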
@@ -146,28 +146,8 @@ struct SearchAlgorithm<miopenConvFwdAlgorithm_t> {
               cudnn_workspace_ptr, workspace_size, false));
     };
-    if (!exhaustive_search && !deterministic) {
-      workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
-      algo = find_result.fwd_algo;
-    } else {
-      auto& temp = ctx.cuda_device_context();
-      AlgorithmsCache<algo_t>& algo_cache =
-          *(framework::ConvSearchCache::Instance().GetForward());
-      auto x_dims = framework::vectorize(args.x->dims());
-      auto w_dims = framework::vectorize(args.w->dims());
-      VLOG(10) << "miopenConvolutionFwdAlgoPerf_t:"
-               << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
-               << args.s << ", args.p" << args.p << ", args.d" << args.d;
-      algo = algo_cache.GetAlgorithm(
-          x_dims, w_dims, args.s, args.p, args.d, 0,
-          static_cast<int64_t>(args.cudnn_dtype), [&]() {
-            workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
-            return find_result.fwd_algo;
-          });
-    }
+    workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
+    algo = find_result.fwd_algo;
     VLOG(3) << "choose algo " << algo;
     return algo;
   }
@@ -208,27 +188,8 @@ struct SearchAlgorithm<miopenConvBwdDataAlgorithm_t> {
               cudnn_workspace_ptr, workspace_size, false));
     };
-    if (!exhaustive_search && !deterministic) {
-      workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
-      algo = find_result.bwd_data_algo;
-    } else {
-      AlgorithmsCache<algo_t>& algo_cache =
-          *(framework::ConvSearchCache::Instance().GetBackwardData());
-      auto x_dims = framework::vectorize(args.x->dims());
-      auto w_dims = framework::vectorize(args.w->dims());
-      VLOG(10) << "miopenConvolutionFwdAlgoPerf_t"
-               << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
-               << args.s << ", args.p" << args.p << ", args.d" << args.d;
-      algo = algo_cache.GetAlgorithm(
-          x_dims, w_dims, args.s, args.p, args.d, 0,
-          static_cast<int64_t>(args.cudnn_dtype), [&]() {
-            workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
-            return find_result.bwd_data_algo;
-          });
-    }
+    workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
+    algo = find_result.bwd_data_algo;
     VLOG(3) << "choose algo " << algo;
     return algo;
   }
@@ -269,27 +230,8 @@ struct SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t> {
               cudnn_workspace_ptr, workspace_size, false));
     };
-    if (!exhaustive_search && !deterministic) {
-      workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
-      algo = find_result.bwd_weights_algo;
-    } else {
-      AlgorithmsCache<algo_t>& algo_cache =
-          *(framework::ConvSearchCache::Instance().GetBackwardFilter());
-      auto x_dims = framework::vectorize(args.x->dims());
-      auto w_dims = framework::vectorize(args.w->dims());
-      VLOG(10) << "miopenConvolutionFwdAlgoPerf_t:"
-               << ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
-               << args.s << ", args.p" << args.p << ", args.d" << args.d;
-      algo = algo_cache.GetAlgorithm(
-          x_dims, w_dims, args.s, args.p, args.d, 0,
-          static_cast<int64_t>(args.cudnn_dtype), [&]() {
-            workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
-            return find_result.bwd_weights_algo;
-          });
-    }
+    workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
+    algo = find_result.bwd_weights_algo;
     VLOG(3) << "choose algo " << algo;
     return algo;
   }
......
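Note: the three hunks above (forward, backward-data, backward-weights) drop the cuDNN-style exhaustive-search/AlgorithmsCache branch and take the find result unconditionally; the code relies on MIOpen's miopenFind*Algorithm calls, which time the applicable algorithms themselves, reading the winner out of find_result. A minimal sketch of the resulting control flow (PerfResult is a hypothetical stand-in for miopenConvAlgoPerf_t):

    #include <vector>

    // PerfResult is a hypothetical stand-in for miopenConvAlgoPerf_t.
    struct PerfResult {
      int fwd_algo;
      float time_ms;
    };

    int ChooseAlgo(const std::vector<PerfResult>& perf_results) {
      // mirrors `algo = find_result.fwd_algo;` in the hunks above:
      // take the Find call's first (fastest) result, no caching branch.
      return perf_results.front().fwd_algo;
    }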
@@ -12,17 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#ifndef PADDLE_WITH_HIP
-// HIP not supported yet
 #include <algorithm>
 #include <string>
 #include "paddle/fluid/framework/op_registry.h"
+#ifdef __HIPCC__
+#define __syncwarp() __all(1)
+#endif
 namespace paddle {
 namespace operators {
+#ifdef __HIPCC__
+#define THREADS_PER_BLOCK 64
+#else
 #define THREADS_PER_BLOCK 32
+#endif
 #define FULL_MASK 0xffffffff
 using framework::Tensor;
@@ -30,14 +35,22 @@ using framework::Tensor;
 template <typename T>
 __forceinline__ __device__ T warpReduceSum(T val) {
   for (int offset = 16; offset > 0; offset /= 2) {
+#ifdef __HIPCC__
+    val += __shfl_down(val, offset);
+#else
     val += __shfl_down_sync(FULL_MASK, val, offset);
+#endif
   }
   return val;
 }
 template <typename T>
 __forceinline__ __device__ T blockReduceSum(T val) {
+#ifdef __HIPCC__
+  static __shared__ T shared[64];
+#else
   static __shared__ T shared[32];
+#endif
   int lane = threadIdx.x % warpSize;
   int wid = threadIdx.x / warpSize;
@@ -483,5 +496,3 @@ REGISTER_OP_CUDA_KERNEL(correlation, ops::CorrelationCUDAKernel<float>,
                         ops::CorrelationCUDAKernel<double>);
 REGISTER_OP_CUDA_KERNEL(correlation_grad, ops::CorrelationCUDAGradKernel<float>,
                         ops::CorrelationCUDAGradKernel<double>);
-#endif  // not PADDLE_WITH_HIP
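Note: correlation_op.cu was previously excluded under `#ifndef PADDLE_WITH_HIP`; the hunks above port it instead. The HIP-specific lines reflect two platform differences: HIP lacks `__syncwarp()` and the masked `*_sync` shuffle variants (hence the commit maps `__syncwarp()` to `__all(1)` and uses the mask-less `__shfl_down`), and an AMD GCN wavefront is 64 lanes wide (hence `THREADS_PER_BLOCK 64` and the 64-slot shared buffer). A sketch of the portable shuffle pattern, using the same 16..1 offset ladder as `warpReduceSum` above:

    #ifdef __HIPCC__
    #define SHFL_DOWN(val, offset) __shfl_down((val), (offset))
    #else
    #define SHFL_DOWN(val, offset) __shfl_down_sync(0xffffffff, (val), (offset))
    #endif

    // Same 16..1 offset ladder as warpReduceSum above: lane 0 of each
    // 32-lane group ends up holding the sum of that group.
    template <typename T>
    __forceinline__ __device__ T WarpReduceSum32(T val) {
      for (int offset = 16; offset > 0; offset /= 2) {
        val += SHFL_DOWN(val, offset);
      }
      return val;
    }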
@@ -32,8 +32,7 @@ if (WITH_GPU OR WITH_ROCM)
   file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_batch_norm_act);\n")
 endif()
 # conv_fusion_op needs cudnn 7 above
-# HIP not support cudnnConvolutionBiasActivationForward
-if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 7100))
+if (NOT ${CUDNN_VERSION} VERSION_LESS 7100)
   op_library(conv_fusion_op)
   file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
 endif()
......
@@ -18,14 +18,18 @@ limitations under the License. */
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/operators/conv_op.h"
 #include "paddle/fluid/operators/math/padding.h"
+#ifdef PADDLE_WITH_HIP
+#include "paddle/fluid/platform/miopen_helper.h"
+#else
 #include "paddle/fluid/platform/cudnn_helper.h"
+#endif
 DECLARE_int64(cudnn_exhaustive_search_times);
 namespace paddle {
 namespace operators {
-#if CUDNN_VERSION >= 7100
+#if PADDLE_WITH_HIP || CUDNN_VERSION >= 7100
 using Tensor = framework::Tensor;
 using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
 using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
@@ -162,7 +166,78 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
     if (input->dims().size() == 5) {
       layout = DataLayout::kNCDHW;
     }
+#ifdef PADDLE_WITH_HIP
+    miopenConvolutionDescriptor_t cudnn_conv_desc =
+        conv_desc.descriptor<T>(padding_common, strides, dilations);
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenSetConvolutionGroupCount(cudnn_conv_desc,
+                                                          groups));
+    // Now only support NCHW
+    std::vector<int> bias_dim = {
+        1, static_cast<int>(transformed_output.dims()[1]), 1, 1};
+    miopenTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
+        layout, framework::vectorize<int>(transformed_input.dims()));
+    miopenTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
+        layout, framework::vectorize<int>(transformed_output.dims()));
+    miopenTensorDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
+        layout, framework::vectorize<int>(filter->dims()));
+    miopenTensorDescriptor_t cudnn_bias_desc =
+        bias_desc.descriptor<T>(layout, bias_dim);
+    miopenActivationDescriptor_t cudnn_act_desc =
+        act_desc.descriptor<T>(activation);
+    miopenConvFwdAlgorithm_t algo;
+    auto handle = dev_ctx.cudnn_handle();
+    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
+    auto x_dims = framework::vectorize(transformed_input.dims());
+    auto f_dims = framework::vectorize(filter->dims());
+    size_t workspace_size = 0;
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        platform::dynload::miopenConvolutionForwardGetWorkSpaceSize(
+            handle, cudnn_filter_desc, cudnn_input_desc, cudnn_conv_desc,
+            cudnn_output_desc, &workspace_size));
+    int find_count;
+    miopenConvAlgoPerf_t find_result;
+    auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::miopenFindConvolutionForwardAlgorithm(
+              handle, cudnn_input_desc, input_data, cudnn_filter_desc,
+              filter_data, cudnn_conv_desc, cudnn_output_desc, output_data,
+              kNUM_CUDNN_FWD_ALGS, &find_count, &find_result,
+              cudnn_workspace_ptr, workspace_size, false));
+    };
+    workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
+    algo = find_result.fwd_algo;
+    VLOG(3) << "cuDNN forward algo " << algo;
+    {
+      ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
+      auto cudnn_func = [&](void* cudnn_workspace) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenConvolutionForward(
+            handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
+            filter_data, cudnn_conv_desc, algo, &beta, cudnn_output_desc,
+            output_data, cudnn_workspace, workspace_size));
+      };
+      workspace_handle.RunFunc(cudnn_func, workspace_size);
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::miopenConvolutionForwardBias(
+              handle, &alpha, cudnn_bias_desc, bias_data, &beta,
+              cudnn_output_desc, output_data));
+      if (activation != "identity") {
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationForward(
+            handle, cudnn_act_desc, &alpha, cudnn_output_desc, output_data,
+            &beta, cudnn_output_desc, output_data));
+      }
+      if (residual) {
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenOpTensor(
+            handle, miopenTensorOpAdd, &alpha, cudnn_output_desc, output_data,
+            &alpha, cudnn_output_desc, residual_data, &beta, cudnn_output_desc,
+            output_data));
+      }
+    }
+#else  // PADDLE_WITH_HIP
     cudnnConvolutionDescriptor_t cudnn_conv_desc =
         conv_desc.descriptor<T>(padding_common, strides, dilations);
     PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -327,6 +402,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
     };
     workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
   }
+#endif
   std::vector<int> channels = ctx.Attr<std::vector<int>>("split_channels");
   if (channels.size()) {
     auto outs = ctx.MultiOutput<framework::Tensor>("Outputs");
@@ -358,8 +434,11 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
-#if CUDNN_VERSION >= 7100
 namespace ops = paddle::operators;
+#if CUDNN_VERSION >= 7100
 REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
                         ops::CUDNNConvFusionOpKernel<double>);
 #endif
+#ifdef PADDLE_WITH_HIP
+REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>);
+#endif
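Note: MIOpen has no counterpart to cudnnConvolutionBiasActivationForward (the limitation the removed CMake comment referred to), so the HIP branch above decomposes conv2d_fusion into separate MIOpen calls: convolution, bias, activation (skipped for "identity"), and an optional residual add. A plain-C++ stand-in for that tail sequence (not MIOpen code; the conv result is assumed already in y):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Stand-ins for the MIOpen calls, applied to a conv result `y`:
    void FusedTail(std::vector<float>& y,                 // miopenConvolutionForward output
                   float bias,                            // miopenConvolutionForwardBias
                   bool apply_activation,                 // miopenActivationForward
                   const std::vector<float>* residual) {  // miopenOpTensor(miopenTensorOpAdd)
      for (float& v : y) v += bias;
      if (apply_activation)
        for (float& v : y) v = std::max(v, 0.0f);         // ReLU as the example activation
      if (residual)
        for (std::size_t i = 0; i < y.size(); ++i) y[i] += (*residual)[i];
    }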
@@ -87,7 +87,11 @@ class Unpool2dMaxFunctor<platform::CUDADeviceContext, T> {
     const T* input_data = input.data<T>();
     const int* indices_data = indices.data<int>();
     T* output_data = output->mutable_data<T>(context.GetPlace());
+#ifdef __HIPCC__
+    int threads = 256;
+#else
     int threads = 1024;
+#endif
     int grid = (input.numel() + threads - 1) / threads;
     KernelUnpool2dMax<T><<<grid, threads, 0, context.stream()>>>(
         input.numel(), input_data, indices_data, input_height, input_width,
@@ -117,7 +121,11 @@ class Unpool2dMaxGradFunctor<platform::CUDADeviceContext, T> {
     const T* output_data = output.data<T>();
     const T* output_grad_data = output_grad.data<T>();
     T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
+#ifdef __HIPCC__
+    int threads = 256;
+#else
     int threads = 1024;
+#endif
     int grid = (input.numel() + threads - 1) / threads;
     KernelUnpool2dMaxGrad<T><<<grid, threads, 0, context.stream()>>>(
         input.numel(), input_data, indices_data, input_height, input_width,
......
@@ -141,7 +141,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double,
                                ops::MemcpyKernel, plat::float16,
                                ops::MemcpyKernel);
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM)
 REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double,
                                 ops::MemcpyKernel, int, ops::MemcpyKernel,
                                 int64_t, ops::MemcpyKernel, bool,
......
@@ -110,6 +110,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
   __macro(miopenActivationBackward);                      \
   __macro(miopenConvolutionBackwardWeights);              \
   __macro(miopenConvolutionForward);                      \
+  __macro(miopenConvolutionForwardBias);                  \
   __macro(miopenConvolutionBackwardBias);                 \
   __macro(miopenConvolutionForwardGetWorkSpaceSize);      \
   __macro(miopenConvolutionBackwardDataGetWorkSpaceSize); \
......
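Note: the new __macro(miopenConvolutionForwardBias) entry routes that symbol through Paddle's dynload layer, which the HIP conv_fusion path above calls as platform::dynload::miopenConvolutionForwardBias. A simplified sketch of what such an entry provides — lazy symbol resolution on first use (names and signature here are illustrative, not Paddle's actual macro expansion):

    #include <dlfcn.h>

    // Illustrative signature only; the real one takes handle/descriptor args.
    using BiasFn = int (*)(...);

    int CallForwardBias(void* libmiopen_handle) {  // handle from dlopen()
      static BiasFn fn = reinterpret_cast<BiasFn>(
          dlsym(libmiopen_handle, "miopenConvolutionForwardBias"));
      return fn != nullptr ? 0 : -1;  // real code would invoke fn(...) here
    }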
@@ -15,6 +15,7 @@
 from __future__ import print_function
 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 import unittest
 import numpy as np
 from op_test import OpTest, skip_check_grad_ci
@@ -39,13 +40,33 @@ class DotOp(OpTest):
         self.check_output()

     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out')
+        if core.is_compiled_with_rocm():
+            self.check_grad(
+                ['X', 'Y'],
+                'Out',
+                user_defined_grads=[self.inputs['Y'], self.inputs['X']])
+        else:
+            self.check_grad(['X', 'Y'], 'Out')

     def test_check_grad_ingore_x(self):
-        self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
+        if core.is_compiled_with_rocm():
+            self.check_grad(
+                ['Y'],
+                'Out',
+                no_grad_set=set("X"),
+                user_defined_grads=[self.inputs['X']])
+        else:
+            self.check_grad(['Y'], 'Out', no_grad_set=set("X"))

     def test_check_grad_ingore_y(self):
-        self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
+        if core.is_compiled_with_rocm():
+            self.check_grad(
+                ['X'],
+                'Out',
+                no_grad_set=set('Y'),
+                user_defined_grads=[self.inputs['Y']])
+        else:
+            self.check_grad(['X'], 'Out', no_grad_set=set('Y'))

     def init_input_output(self):
         self.x = np.random.uniform(0.1, 1, [121]).astype(self.dtype)
@@ -64,6 +85,15 @@ class DotOpBatch(DotOp):
                                                            [11, 12])
         self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1])

+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y'], 'Out')
+
+    def test_check_grad_ingore_x(self):
+        self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
+
+    def test_check_grad_ingore_y(self):
+        self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
+
 class TestDotOpError(unittest.TestCase):
     def test_errors(self):
......
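Note: for out = sum(x * y), the analytic gradients are d(out)/dx = y and d(out)/dy = x, which is exactly what user_defined_grads supplies on ROCm (sidestepping numeric-gradient precision issues there), while DotOpBatch restores the default numeric checks. A quick central-difference confirmation of that identity in plain C++:

    #include <cassert>
    #include <cmath>
    #include <vector>

    int main() {
      std::vector<double> x = {0.5, 2.0}, y = {3.0, 0.25};
      // out(x0) = x0*y[0] + x[1]*y[1]; central difference at x[0].
      auto out = [&](double x0) { return x0 * y[0] + x[1] * y[1]; };
      const double eps = 1e-6;
      double numeric = (out(x[0] + eps) - out(x[0] - eps)) / (2.0 * eps);
      assert(std::fabs(numeric - y[0]) < 1e-8);  // analytic gradient w.r.t. x[0] is y[0]
      return 0;
    }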
@@ -76,7 +76,10 @@ class SimpleNet(fluid.Layer):
 class TestDygraphSimpleNet(unittest.TestCase):
     def test_simple_net(self):
         for is_sparse in [True, False]:
-            for dtype in ["float32", "float64"]:
+            dtype_list = ["float32"]
+            if not core.is_compiled_with_rocm():
+                dtype_list.append("float64")
+            for dtype in dtype_list:
                 self.simple_net_float32(is_sparse, dtype)

     def simple_net_float32(self, is_sparse, dtype):
......
@@ -82,7 +82,10 @@ class SimpleNet(fluid.Layer):
 class TestDygraphSimpleNet(unittest.TestCase):
     def test_simple_net(self):
         for is_sparse in [True, False]:
-            for dtype in ["float32", "float64"]:
+            dtype_list = ["float32"]
+            if not core.is_compiled_with_rocm():
+                dtype_list.append("float64")
+            for dtype in dtype_list:
                 self.simple_net_float(is_sparse, dtype)

     def simple_net_float(self, is_sparse, dtype):
......