Unverified commit 028de857, authored by Leo Chen, committed by GitHub

fix dtype error of compare op, test=develop (#25059)

Parent 9ed16a43
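Why this change is needed: compare ops take float (or integer) inputs but must return a bool tensor, and before this patch the common-broadcast CUDA path below wrote its output with the input element type T, so a broadcasted comparison run on GPU came back with the wrong dtype. The patch threads an extra OutType template parameter through CommonForwardBroadcastCUDAKernel and CommonForwardBroadcastCUDA, defaulting to T so existing elementwise kernels are unaffected while the compare kernels can produce bool. A minimal sketch of the expected behaviour, mirroring the unit test added at the end of this diff (float32 inputs and the 2020-era fluid dygraph API are assumed):

# Minimal sketch, mirroring the new unit test below: a broadcasted compare
# op must return a BOOL tensor whose values match the NumPy result.
import numpy as np
import paddle.fluid as fluid

a_np = np.random.uniform(-1, 1, [10, 1, 10]).astype('float32')
b_np = np.random.uniform(-1, 1, [1, 1, 10]).astype('float32')

with fluid.dygraph.guard():
    a = fluid.dygraph.to_variable(a_np)
    b = fluid.dygraph.to_variable(b_np)
    out = a != b  # broadcasts [1, 1, 10] against [10, 1, 10]
    # Output dtype must be BOOL even though the inputs are float32.
    assert out.dtype == fluid.core.VarDesc.VarType.BOOL
    assert np.array_equal(out.numpy(), a_np != b_np)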
@@ -15,10 +15,12 @@ limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <functional> // for multiplies
#include <iterator>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
@@ -30,6 +32,7 @@ limitations under the License. */
#ifdef __NVCC__
#include <cuda.h>
#include <thrust/iterator/iterator_adaptor.h>
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
@@ -194,11 +197,11 @@ void CommonForwardBroadcastCPU(const framework::Tensor *x,
}
#ifdef __NVCC__
-template <typename Functor, typename T>
+template <typename Functor, typename T, typename OutType = T>
__global__ void CommonForwardBroadcastCUDAKernel(
const int *x_strides_array, const int *y_strides_array,
-    const int *out_dims_array, const T *x, const T *y, T *out, int out_size,
-    int max_dim, Functor func, const bool is_xsize_larger) {
+    const int *out_dims_array, const T *x, const T *y, OutType *out,
+    int out_size, int max_dim, Functor func, const bool is_xsize_larger) {
for (int out_index = blockIdx.x * blockDim.x + threadIdx.x;
out_index < out_size; out_index += blockDim.x * gridDim.x) {
int x_index = 0;
@@ -220,7 +223,7 @@ __global__ void CommonForwardBroadcastCUDAKernel(
}
}
-template <typename Functor, typename T>
+template <typename Functor, typename T, typename OutType = T>
void CommonForwardBroadcastCUDA(
const framework::Tensor *x, const framework::Tensor *y,
framework::Tensor *z, int *x_dims_array, int *y_dims_array,
@@ -230,7 +233,7 @@ void CommonForwardBroadcastCUDA(
auto cplace = platform::CPUPlace();
const T *x_data = x->data<T>();
const T *y_data = y->data<T>();
-  T *out_data = z->mutable_data<T>(ctx.GetPlace());
+  OutType *out_data = z->mutable_data<OutType>(ctx.GetPlace());
std::vector<int> x_strides_array(max_dim);
std::vector<int> y_strides_array(max_dim);
@@ -268,7 +271,7 @@ void CommonForwardBroadcastCUDA(
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
CommonForwardBroadcastCUDAKernel<
-      Functor, T><<<gird_size, block_size, 0, ctx.stream()>>>(
+      Functor, T, OutType><<<gird_size, block_size, 0, ctx.stream()>>>(
x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu, x_data,
y_data, out_data, out_size, max_dim, func, is_xsize_larger);
}
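For context on the kernel this hunk parameterizes: each CUDA thread takes one flat output index, decomposes it along out_dims_array, and accumulates offsets into x and y using the per-axis stride arrays (the exact in-kernel arithmetic is elided from the hunk above). A rough pure-Python rendering of that index mapping, illustrative only and assuming the usual convention that a broadcast axis carries stride 0:

# Illustrative pure-Python sketch of stride-based broadcast indexing (not
# Paddle code): map a flat output index to flat indices into x and y.
import numpy as np

def broadcast_indices(out_index, out_dims, x_strides, y_strides):
    x_index = y_index = 0
    q = out_index
    for i in reversed(range(len(out_dims))):  # peel off the innermost axis first
        q, coord = divmod(q, out_dims[i])     # coordinate along axis i
        x_index += coord * x_strides[i]
        y_index += coord * y_strides[i]
    return x_index, y_index

# Shapes from the new unit test: x is [10, 1, 10], y is [1, 1, 10],
# broadcast output is [10, 1, 10] (row-major flat layout).
out_dims = [10, 1, 10]
x_strides = [10, 10, 1]  # dense strides of x
y_strides = [0, 10, 1]   # y's leading axis is broadcast, so it gets stride 0

x = np.arange(100, dtype=np.float32).reshape(10, 1, 10)
y = np.arange(10, dtype=np.float32).reshape(1, 1, 10)
for out_index in (0, 57, 99):
    xi, yi = broadcast_indices(out_index, out_dims, x_strides, y_strides)
    assert x.ravel()[xi] == np.broadcast_to(x, out_dims).ravel()[out_index]
    assert y.ravel()[yi] == np.broadcast_to(y, out_dims).ravel()[out_index]

Because only the element type of the destination pointer differs between an arithmetic op and a comparison, adding the OutType parameter is sufficient; the indexing logic above is shared unchanged.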
@@ -1796,7 +1799,7 @@ void CommonElementwiseBroadcastForward(
if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
-    CommonForwardBroadcastCUDA<Functor, T>(
+    CommonForwardBroadcastCUDA<Functor, T, OutType>(
x, y, z, x_dims_array.data(), y_dims_array.data(),
out_dims_array.data(), max_dim,
ctx.template device_context<platform::CUDADeviceContext>(), func,
......
@@ -273,6 +273,16 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
        self.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))
        self.assertTrue(np.array_equal(res1.numpy(), res3.numpy()))

+    def test_conpare_op_broadcast(self):
+        a_np = np.random.uniform(-1, 1, [10, 1, 10]).astype(self.dtype)
+        b_np = np.random.uniform(-1, 1, [1, 1, 10]).astype(self.dtype)
+        with fluid.dygraph.guard():
+            a = fluid.dygraph.to_variable(a_np)
+            b = fluid.dygraph.to_variable(b_np)
+            self.assertEqual((a != b).dtype, fluid.core.VarDesc.VarType.BOOL)
+            self.assertTrue(np.array_equal((a != b).numpy(), a_np != b_np))
if __name__ == '__main__':
    unittest.main()