Unverified  Commit d01f6269  authored by F furnace, committed by GitHub

update mv op according to PR#27024 (#27474)

Parent 9d783aed
@@ -42,21 +42,21 @@ class MVOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "mv");
 
     auto dim_x = context->GetInputDim("X");
-    auto dim_y = context->GetInputDim("Vec");
+    auto dim_vec = context->GetInputDim("Vec");
     PADDLE_ENFORCE_EQ(
         dim_x.size(), 2,
         platform::errors::InvalidArgument(
             "The rank of input X should be 2, but is %d", dim_x.size()));
     PADDLE_ENFORCE_EQ(
-        dim_y.size(), 1,
+        dim_vec.size(), 1,
         platform::errors::InvalidArgument(
-            "The rank of input Vec should be 1, but is %d", dim_y.size()));
-    PADDLE_ENFORCE_EQ(dim_x[1] == dim_y[0], true,
+            "The rank of input Vec should be 1, but is %d", dim_vec.size()));
+    PADDLE_ENFORCE_EQ(dim_x[1], dim_vec[0],
                       platform::errors::InvalidArgument(
-                          "The length of input X' second dim should equal the "
-                          "length of input Vec,"
-                          " but X[%d, %d], Vec[%d]",
-                          dim_x[0], dim_x[1], dim_y[0]));
+                          "X's second dimension is expected to be equal to "
+                          "Vec's first dimension"
+                          "but recieved X'shape = [%s], Vec's shape = [%s]",
+                          dim_x, dim_vec));
 
     framework::DDim dim_out = framework::make_ddim({dim_x[0]});
......
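The InferShape hunk above renames `dim_y` to `dim_vec` and compares the inner dimensions directly (`dim_x[1]` vs `dim_vec[0]`), reporting whole shapes in the error message. As a rough illustration of the contract it enforces, here is a minimal Python sketch using the public `paddle.mv` API; the tensor values and the mismatched 99-element vector are made up, and the exact exception class raised by the failed check depends on the Paddle version.

```python
import numpy as np
import paddle

# Shape contract enforced by MVOp::InferShape above:
#   X: [m, n] (rank 2), Vec: [n] (rank 1)  ->  Out: [m]
x = paddle.to_tensor(np.random.rand(5, 100).astype("float64"))
vec = paddle.to_tensor(np.random.rand(100).astype("float64"))

out = paddle.mv(x, vec)  # Out[i] = sum_j X[i, j] * Vec[j]
print(out.shape)         # [5]

# A mismatched inner dimension (100 vs 99) trips the
# PADDLE_ENFORCE_EQ(dim_x[1], dim_vec[0], ...) check above.
bad_vec = paddle.to_tensor(np.random.rand(99).astype("float64"))
try:
    paddle.mv(x, bad_vec)
except Exception as err:  # exact exception class varies by Paddle version
    print("shape check tripped:", type(err).__name__)
```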
@@ -19,8 +19,8 @@ namespace paddle {
 namespace operators {
 
 template <typename T>
-__global__ void MVGradCUDAKernel(const int m, const int n, const T *dout,
-                                 const T *vec, T *dx) {
+__global__ void MVGradDxCUDAKernel(const int m, const int n, const T *dout,
+                                   const T *vec, T *dx) {
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   for (; idx < m * n; idx += blockDim.x * gridDim.x) {
     int i = idx / n;
@@ -52,32 +52,31 @@ class MVGradKernel<platform::CUDADeviceContext, T>
     int m = dim_x[0];
     int n = dim_x[1];
 
-    dx->Resize(framework::make_ddim({m * n}));
-
     // get data ptr
     const T *x_data = x->data<T>();
     const T *vec_data = vec->data<T>();
     const T *dout_data = dout->data<T>();
 
-    T *dx_data = dx->mutable_data<T>(context.GetPlace());
-    T *dvec_data = dvec->mutable_data<T>(context.GetPlace());
-
     auto &dev_ctx =
         context.template device_context<platform::CUDADeviceContext>();
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
 
-    // calculate dx
     auto stream = context.cuda_device_context().stream();
     auto config = GetGpuLaunchConfig1D(dev_ctx, m * n);
-    MVGradCUDAKernel<
-        T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
-        m, n, dout_data, vec_data, dx_data);
-    dx->Resize(framework::make_ddim({m, n}));
 
-    // calculate dvec
-    blas.GEMV(true, dim_x[0], dim_x[1], static_cast<T>(1), x_data, dout_data,
-              static_cast<T>(0), dvec_data);
+    if (dx) {
+      T *dx_data = dx->mutable_data<T>(context.GetPlace());
+
+      MVGradDxCUDAKernel<
+          T><<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
+          m, n, dout_data, vec_data, dx_data);
+    }
+
+    if (dvec) {
+      T *dvec_data = dvec->mutable_data<T>(context.GetPlace());
+
+      blas.GEMV(true, dim_x[0], dim_x[1], static_cast<T>(1), x_data, dout_data,
+                static_cast<T>(0), dvec_data);
+    }
   }
 };
......
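The CUDA grad kernel above computes `dX[i][j] = dOut[i] * Vec[j]` (an outer product, via `MVGradDxCUDAKernel`) and `dVec = Xᵀ·dOut` (the transposed `blas.GEMV` call), and each piece now runs only when its output tensor exists. Below is a small NumPy sketch of that reference math with a finite-difference check of `dVec`; the helper name `mv_grad` and the test values are illustrative, not part of the PR.

```python
import numpy as np

# Reference math for the gradients computed above:
#   dX[i, j] = dOut[i] * Vec[j]   (outer product, MVGradDxCUDAKernel / CPU loop)
#   dVec     = X^T . dOut         (blas.GEMV with trans = true)
def mv_grad(x, vec, dout, need_dx=True, need_dvec=True):
    dx = np.outer(dout, vec) if need_dx else None    # mirrors `if (dx) { ... }`
    dvec = x.T.dot(dout) if need_dvec else None      # mirrors `if (dvec) { ... }`
    return dx, dvec

# Finite-difference check of dVec on a random case.
rng = np.random.default_rng(0)
m, n = 4, 3
x, vec, dout = rng.random((m, n)), rng.random(n), rng.random(m)

_, dvec = mv_grad(x, vec, dout, need_dx=False)
eps = 1e-6
for j in range(n):
    delta = np.zeros(n)
    delta[j] = eps
    numeric = dout.dot(x.dot(vec + delta) - x.dot(vec - delta)) / (2 * eps)
    assert np.isclose(dvec[j], numeric)
```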
@@ -74,30 +74,30 @@ class MVGradKernel : public framework::OpKernel<T> {
     int m = dim_x[0];
     int n = dim_x[1];
 
-    dx->Resize(framework::make_ddim({m * n}));
-
     // get data ptr
     const T *x_data = x->data<T>();
     const T *vec_data = vec->data<T>();
     const T *dout_data = dout->data<T>();
 
-    T *dx_data = dx->mutable_data<T>(context.GetPlace());
-    T *dvec_data = dvec->mutable_data<T>(context.GetPlace());
-
-    auto &dev_ctx = context.template device_context<DeviceContext>();
-    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
-
-    // calculate dx
-    for (int i = 0; i < m; ++i) {
-      for (int j = 0; j < n; ++j)
-        dx_data[i * n + j] = dout_data[i] * vec_data[j];
-    }
-    dx->Resize(framework::make_ddim({m, n}));
+    if (dx) {
+      T *dx_data = dx->mutable_data<T>(context.GetPlace());
+
+      for (int i = 0; i < m; ++i) {
+        for (int j = 0; j < n; ++j) {
+          dx_data[i * n + j] = dout_data[i] * vec_data[j];
+        }
+      }
+    }
 
-    // calculate dvec
-    blas.GEMV(true, dim_x[0], dim_x[1], static_cast<T>(1), x_data, dout_data,
-              static_cast<T>(0), dvec_data);
+    if (dvec) {
+      T *dvec_data = dvec->mutable_data<T>(context.GetPlace());
+
+      auto &dev_ctx = context.template device_context<DeviceContext>();
+      auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+
+      blas.GEMV(true, dim_x[0], dim_x[1], static_cast<T>(1), x_data, dout_data,
+                static_cast<T>(0), dvec_data);
+    }
   }
 };
......
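The CPU kernel above gets the same restructuring as the CUDA one: each gradient is computed only when the framework actually allocated an output tensor for it, which happens only for trainable inputs. Here is a minimal dygraph sketch of the situation those `if (dx)` / `if (dvec)` guards cover, assuming a Paddle 2.x install where dygraph is the default mode and `backward()` on a non-scalar tensor uses an all-ones upstream gradient; the tensor values are illustrative.

```python
import numpy as np
import paddle

# `vec` is frozen with stop_gradient=True, so no grad tensor is allocated for
# it -- exactly the case the `if (dvec)` guard above skips.
x = paddle.to_tensor(np.random.rand(5, 100).astype("float64"),
                     stop_gradient=False)
vec = paddle.to_tensor(np.random.rand(100).astype("float64"),
                       stop_gradient=True)

out = paddle.mv(x, vec)
out.backward()             # upstream gradient dOut defaults to all ones

print(x.grad is not None)  # True: dX = dOut outer Vec was computed
# No gradient is materialized for `vec`; that branch is short-circuited by the
# new `if (dvec)` guard in both the CPU and CUDA kernels.
```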
@@ -20,6 +20,7 @@ import paddle
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 import paddle.fluid.core as core
+from paddle.static import program_guard, Program
 from op_test import OpTest
@@ -37,7 +38,7 @@ class TestMVOp(OpTest):
         self.check_grad(['X', 'Vec'], 'Out')
 
     def init_config(self):
-        self.x = np.random.random((5, 100)).astype("float64")
+        self.x = np.random.random((2, 100)).astype("float64")
         self.vec = np.random.random((100)).astype("float64")
@@ -57,21 +58,36 @@ class TestMVAPI(unittest.TestCase):
         paddle.enable_static()
 
     def test_static_graph(self):
-        paddle.enable_static()
-
-        self.input_x = np.random.rand(5, 100).astype("float64")
-        self.input_vec = np.random.rand(100).astype("float64")
-
-        data_x = paddle.static.data("x", shape=[5, 100], dtype="float64")
-        data_vec = paddle.static.data("vec", shape=[100], dtype="float64")
-        result_vec = paddle.mv(data_x, data_vec)
-
-        self.place = paddle.CPUPlace()
-        exe = paddle.static.Executor(self.place)
-        res, = exe.run(feed={"x": self.input_x,
-                             "vec": self.input_vec},
-                       fetch_list=[result_vec])
-        z_expected = np.array(np.dot(self.input_x, self.input_vec))
-        self.assertTrue(np.allclose(res, z_expected))
+        for x_stop_gradient in [False, True]:
+            for vec_stop_gradient in [False, True]:
+                paddle.enable_static()
+                train_program = Program()
+                startup_program = Program()
+
+                self.input_x = np.random.rand(5, 100).astype("float64")
+                self.input_vec = np.random.rand(100).astype("float64")
+
+                with program_guard(train_program, startup_program):
+                    data_x = paddle.static.data(
+                        "x", shape=[5, 100], dtype="float64")
+                    data_vec = paddle.static.data(
+                        "vec", shape=[100], dtype="float64")
+
+                    data_x.stop_gradient = x_stop_gradient
+                    data_vec.stop_gradient = vec_stop_gradient
+
+                    result_vec = paddle.mv(data_x, data_vec)
+
+                    self.place = paddle.CPUPlace()
+                    exe = paddle.static.Executor(self.place)
+                    res, = exe.run(
+                        feed={"x": self.input_x,
+                              "vec": self.input_vec},
+                        fetch_list=[result_vec])
+
+                    z_expected = np.array(np.dot(self.input_x, self.input_vec))
+                    self.assertTrue(np.allclose(res, z_expected))
 
 
 class TestMVError(unittest.TestCase):
......
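The updated `test_static_graph` above sweeps all four `stop_gradient` combinations but only fetches the forward result. Below is a hedged sketch of how the same static program could also run the backward pass and fetch `dX`, so the new `if (dx)` / `if (dvec)` branches are exercised end to end; `paddle.static.append_backward` and the `"@GRAD"` variable-name suffix are standard Paddle conventions, but this extra check is not part of the diff.

```python
import numpy as np
import paddle
from paddle.static import Program, program_guard

paddle.enable_static()

# Build the same graph as the test, freeze Vec, and append the backward pass.
main_prog, startup_prog = Program(), Program()
with program_guard(main_prog, startup_prog):
    data_x = paddle.static.data("x", shape=[5, 100], dtype="float64")
    data_vec = paddle.static.data("vec", shape=[100], dtype="float64")
    data_x.stop_gradient = False
    data_vec.stop_gradient = True          # the mv grad kernel skips dVec
    out = paddle.mv(data_x, data_vec)
    loss = paddle.sum(out)
    paddle.static.append_backward(loss)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)

input_x = np.random.rand(5, 100).astype("float64")
input_vec = np.random.rand(100).astype("float64")
out_np, dx_np = exe.run(main_prog,
                        feed={"x": input_x, "vec": input_vec},
                        fetch_list=[out, data_x.name + "@GRAD"])

np.testing.assert_allclose(out_np, input_x.dot(input_vec))
# With loss = sum(out), dOut is all ones, so dX[i, j] = Vec[j] for every row i.
np.testing.assert_allclose(dx_np, np.tile(input_vec, (5, 1)))
```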