diff --git a/paddle/fluid/operators/bincount_op.cc b/paddle/fluid/operators/bincount_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8b2fa60f8722e55fed718f8540633a57a0866f6b --- /dev/null +++ b/paddle/fluid/operators/bincount_op.cc @@ -0,0 +1,116 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/bincount_op.h" + +#include +#include +#include + +namespace paddle { +namespace operators { + +using framework::OpKernelType; +using framework::Tensor; + +class BincountOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + platform::errors::InvalidArgument( + "Input(X) of BincountOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + platform::errors::InvalidArgument( + "Output(Out) of BincountOp should not be null.")); + + auto input_dim = ctx->GetInputDim("X"); + auto minlength = ctx->Attrs().Get("minlength"); + + PADDLE_ENFORCE_GE(minlength, 0, + platform::errors::InvalidArgument( + "The minlength should be greater than or equal to 0." + "But received minlength is %d", + minlength)); + + PADDLE_ENFORCE_EQ(input_dim.size(), 1, + platform::errors::InvalidArgument( + "The 'shape' of Input(X) must be 1-D tensor." + "But the dimension of Input(X) is [%d]", + input_dim.size())); + + if (ctx->HasInput("Weights")) { + auto weights_dim = ctx->GetInputDim("Weights"); + PADDLE_ENFORCE_EQ(weights_dim.size(), 1, + platform::errors::InvalidArgument( + "The 'shape' of Input(Weights) must be 1-D tensor." + "But the dimension of Input(Weights) is [%d]", + weights_dim.size())); + + PADDLE_ENFORCE_EQ( + weights_dim[0], input_dim[0], + platform::errors::InvalidArgument( + "The 'shape' of Input(Weights) must be equal to the 'shape' of " + "Input(X)." + "But received: the 'shape' of Input(Weights) is [%s]," + "the 'shape' of Input(X) is [%s]", + weights_dim, input_dim)); + } + + ctx->SetOutputDim("Out", framework::make_ddim({-1})); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const { + auto data_type = + ctx.HasInput("Weights") + ? OperatorWithKernel::IndicateVarDataType(ctx, "Weights") + : OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(data_type, ctx.device_context()); + } +}; + +class BincountOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) The input tensor of Bincount op,"); + AddInput("Weights", "(Tensor) The weights tensor of Bincount op,") + .AsDispensable(); + AddOutput("Out", "(Tensor) The output tensor of Bincount op,"); + AddAttr("minlength", "(int) The minimal numbers of bins") + .SetDefault(0) + .EqualGreaterThan(0); + AddComment(R"DOC( + Bincount Operator. + Computes frequency of each value in the input tensor. + Elements of input tensor should be non-negative ints. + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + bincount, ops::BincountOp, ops::BincountOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + bincount, ops::BincountKernel, + ops::BincountKernel, + ops::BincountKernel, + ops::BincountKernel); diff --git a/paddle/fluid/operators/bincount_op.cu b/paddle/fluid/operators/bincount_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..757f72862910695503d0f7e8c6950b5ee84c105f --- /dev/null +++ b/paddle/fluid/operators/bincount_op.cu @@ -0,0 +1,160 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/bincount_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" +#include "paddle/fluid/platform/gpu_launch_config.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using platform::PADDLE_CUDA_NUM_THREADS; + +inline int GET_BLOCKS(const int N) { + return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; +} + +template +__global__ void KernelBincount(const InputT* input, const int total_elements, + const bool has_weights, const T* weights, + OutT* output) { + if (!has_weights) { + for (int i = threadIdx.x; i < total_elements; i += blockDim.x) { + paddle::platform::CudaAtomicAdd(&output[input[i]], 1L); + } + } else { + for (int i = threadIdx.x; i < total_elements; i += blockDim.x) { + paddle::platform::CudaAtomicAdd(&output[input[i]], + static_cast(weights[i])); + } + } +} + +template +void BincountCUDAInner(const framework::ExecutionContext& context) { + const Tensor* input = context.Input("X"); + const Tensor* weights = context.Input("Weights"); + Tensor* output = context.Output("Out"); + auto& minlength = context.Attr("minlength"); + + const InputT* input_data = input->data(); + + const int input_numel = input->numel(); + + if (input_data == nullptr) { + framework::DDim out_dim{0}; + output->Resize(out_dim); + output->mutable_data(context.GetPlace()); + return; + } + auto input_x = framework::EigenVector::Flatten(*input); + + framework::Tensor input_min_t, input_max_t; + auto* input_max_data = + input_max_t.mutable_data({1}, context.GetPlace()); + auto* input_min_data = + input_min_t.mutable_data({1}, context.GetPlace()); + + auto input_max_scala = framework::EigenScalar::From(input_max_t); + auto input_min_scala = framework::EigenScalar::From(input_min_t); + + auto* place = context.template device_context().eigen_device(); + input_max_scala.device(*place) = input_x.maximum(); + input_min_scala.device(*place) = input_x.minimum(); + + Tensor input_min_cpu, input_max_cpu; + TensorCopySync(input_max_t, platform::CPUPlace(), &input_max_cpu); + TensorCopySync(input_min_t, platform::CPUPlace(), &input_min_cpu); + + InputT input_min = input_min_cpu.data()[0]; + + PADDLE_ENFORCE_GE( + input_min, static_cast(0), + platform::errors::InvalidArgument( + "The elements in input tensor must be non-negative ints")); + + int64_t output_size = + static_cast(input_max_cpu.data()[0]) + 1L; + + output_size = std::max(output_size, static_cast(minlength)); + framework::DDim out_dim{output_size}; + output->Resize(out_dim); + + bool has_weights = (weights != nullptr); + + const T* weights_data = has_weights ? weights->data() : nullptr; + + auto stream = + context.template device_context().stream(); + + if (!has_weights) { + int64_t* output_data = output->mutable_data(context.GetPlace()); + math::SetConstant()( + context.template device_context(), output, 0L); + + KernelBincount<<>>( + input_data, input_numel, has_weights, weights_data, output_data); + } else { + const auto& weights_type = weights->type(); + + if (weights_type == framework::proto::VarType::FP32) { + float* output_data = output->mutable_data(context.GetPlace()); + math::SetConstant()( + context.template device_context(), output, + static_cast(0)); + + KernelBincount<<>>( + input_data, input_numel, has_weights, weights_data, output_data); + } else { + double* output_data = output->mutable_data(context.GetPlace()); + math::SetConstant()( + context.template device_context(), output, + static_cast(0)); + + KernelBincount<<>>( + input_data, input_numel, has_weights, weights_data, output_data); + } + } +} + +template +class BincountCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("X"); + const auto& input_type = input->type(); + + if (input_type == framework::proto::VarType::INT32) { + BincountCUDAInner(context); + } else if (input_type == framework::proto::VarType::INT64) { + BincountCUDAInner(context); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + bincount, ops::BincountCUDAKernel, + ops::BincountCUDAKernel, + ops::BincountCUDAKernel, + ops::BincountCUDAKernel); diff --git a/paddle/fluid/operators/bincount_op.h b/paddle/fluid/operators/bincount_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a142332bce2669987af5923cc879f563d4523bf6 --- /dev/null +++ b/paddle/fluid/operators/bincount_op.h @@ -0,0 +1,109 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +void BincountInner(const framework::ExecutionContext& context) { + const Tensor* input = context.Input("X"); + const Tensor* weights = context.Input("Weights"); + Tensor* output = context.Output("Out"); + auto& minlength = context.Attr("minlength"); + + const InputT* input_data = input->data(); + + auto input_numel = input->numel(); + + if (input_data == nullptr) { + framework::DDim out_dim{0}; + output->Resize(out_dim); + output->mutable_data(context.GetPlace()); + return; + } + + PADDLE_ENFORCE_GE( + *std::min_element(input_data, input_data + input_numel), + static_cast(0), + platform::errors::InvalidArgument( + "The elements in input tensor must be non-negative ints")); + + int64_t output_size = static_cast(*std::max_element( + input_data, input_data + input_numel)) + + 1L; + output_size = std::max(output_size, static_cast(minlength)); + + framework::DDim out_dim{output_size}; + output->Resize(out_dim); + + bool has_weights = (weights != nullptr); + + if (has_weights) { + const T* weights_data = weights->data(); + const auto& weights_type = weights->type(); + if (weights_type == framework::proto::VarType::FP32) { + float* output_data = output->mutable_data(context.GetPlace()); + math::SetConstant()( + context.template device_context(), output, + static_cast(0)); + for (int64_t i = 0; i < input_numel; i++) { + output_data[input_data[i]] += static_cast(weights_data[i]); + } + } else { + double* output_data = output->mutable_data(context.GetPlace()); + math::SetConstant()( + context.template device_context(), output, + static_cast(0)); + for (int64_t i = 0; i < input_numel; i++) { + output_data[input_data[i]] += static_cast(weights_data[i]); + } + } + + } else { + int64_t* output_data = output->mutable_data(context.GetPlace()); + math::SetConstant()( + context.template device_context(), output, 0L); + for (int64_t i = 0; i < input_numel; i++) { + output_data[input_data[i]] += 1L; + } + } +} + +template +class BincountKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("X"); + const auto& input_type = input->type(); + + if (input_type == framework::proto::VarType::INT32) { + BincountInner(context); + } else if (input_type == framework::proto::VarType::INT64) { + BincountInner(context); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 1e1f195c5c617112871780df21fddc8de4278072..54ea0f2aee17f90dae791ffd47e8e1d7f673de7c 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -40,6 +40,7 @@ // need to manually specify them in this map. std::map> op_ins_map = { {"layer_norm", {"X", "Scale", "Bias"}}, + {"bincount", {"X", "Weights"}}, {"fused_attention", {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW", "OutLinearBias", "Ln2Scale", "Ln2Bias"}}, diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 2051a4f6fcd50dab42f5e542807d8165e4c5ad2f..471f6f395351ec559069626085b9f76e8f2b3497 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -98,6 +98,7 @@ from .tensor.linalg import cross # noqa: F401 from .tensor.linalg import cholesky # noqa: F401 from .tensor.linalg import bmm # noqa: F401 from .tensor.linalg import histogram # noqa: F401 +from .tensor.linalg import bincount # noqa: F401 from .tensor.linalg import mv # noqa: F401 from .tensor.logic import equal # noqa: F401 from .tensor.logic import greater_equal # noqa: F401 @@ -398,6 +399,7 @@ __all__ = [ # noqa 'bitwise_not', 'mm', 'flip', + 'bincount', 'histogram', 'multiplex', 'CUDAPlace', diff --git a/python/paddle/fluid/tests/unittests/test_bincount_op.py b/python/paddle/fluid/tests/unittests/test_bincount_op.py new file mode 100644 index 0000000000000000000000000000000000000000..851bf7b01125a3de1d1f442de529f2b9a76e31a9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_bincount_op.py @@ -0,0 +1,205 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import Program, program_guard +from op_test import OpTest + +paddle.enable_static() + + +class TestBincountOpAPI(unittest.TestCase): + """Test bincount api.""" + + def test_static_graph(self): + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + inputs = fluid.data(name='input', dtype='int64', shape=[7]) + weights = fluid.data(name='weights', dtype='int64', shape=[7]) + output = paddle.bincount(inputs, weights=weights) + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(startup_program) + img = np.array([0, 1, 1, 3, 2, 1, 7]).astype(np.int64) + w = np.array([0, 1, 1, 2, 2, 1, 0]).astype(np.int64) + res = exe.run(train_program, + feed={'input': img, + 'weights': w}, + fetch_list=[output]) + actual = np.array(res[0]) + expected = np.bincount(img, weights=w) + self.assertTrue( + (actual == expected).all(), + msg='bincount output is wrong, out =' + str(actual)) + + def test_dygraph(self): + with fluid.dygraph.guard(): + inputs_np = np.array([0, 1, 1, 3, 2, 1, 7]).astype(np.int64) + inputs = fluid.dygraph.to_variable(inputs_np) + actual = paddle.bincount(inputs) + expected = np.bincount(inputs) + self.assertTrue( + (actual.numpy() == expected).all(), + msg='bincount output is wrong, out =' + str(actual.numpy())) + + +class TestBincountOpError(unittest.TestCase): + """Test bincount op error.""" + + def run_network(self, net_func): + with fluid.dygraph.guard(): + net_func() + + def test_input_value_error(self): + """Test input tensor should be non-negative.""" + + def net_func(): + input_value = paddle.to_tensor([1, 2, 3, 4, -5]) + paddle.bincount(input_value) + + with self.assertRaises(ValueError): + self.run_network(net_func) + + def test_input_shape_error(self): + """Test input tensor should be 1-D tansor.""" + + def net_func(): + input_value = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]) + paddle.bincount(input_value) + + with self.assertRaises(ValueError): + self.run_network(net_func) + + def test_minlength_value_error(self): + """Test minlength is non-negative ints.""" + + def net_func(): + input_value = paddle.to_tensor([1, 2, 3, 4, 5]) + paddle.bincount(input_value, minlength=-1) + + with self.assertRaises(IndexError): + self.run_network(net_func) + + def test_input_type_errors(self): + """Test input tensor should only contain non-negative ints.""" + + def net_func(): + input_value = paddle.to_tensor([1., 2., 3., 4., 5.]) + paddle.bincount(input_value) + + with self.assertRaises(TypeError): + self.run_network(net_func) + + def test_weights_shape_error(self): + """Test weights tensor should have the same shape as input tensor.""" + + def net_func(): + input_value = paddle.to_tensor([1, 2, 3, 4, 5]) + weights = paddle.to_tensor([1, 1, 1, 1, 1, 1]) + paddle.bincount(input_value, weights=weights) + + with self.assertRaises(ValueError): + self.run_network(net_func) + + +class TestBincountOp(OpTest): + # without weights + def setUp(self): + self.op_type = "bincount" + self.init_test_case() + self.inputs = {"X": self.np_input} + self.attrs = {"minlength": self.minlength} + self.outputs = {"Out": self.Out} + + def init_test_case(self): + self.minlength = 0 + self.np_input = np.random.randint(low=0, high=20, size=10) + self.Out = np.bincount(self.np_input, minlength=self.minlength) + + def test_check_output(self): + self.check_output() + + +class TestCase1(TestBincountOp): + # with weights(FLOAT32) + def setUp(self): + self.op_type = "bincount" + self.init_test_case() + self.inputs = {"X": self.np_input, "Weights": self.np_weights} + self.attrs = {"minlength": self.minlength} + self.outputs = {"Out": self.Out} + + def init_test_case(self): + self.minlength = 0 + self.np_weights = np.random.randint( + low=0, high=20, size=10).astype(np.float32) + self.np_input = np.random.randint(low=0, high=20, size=10) + self.Out = np.bincount( + self.np_input, weights=self.np_weights, + minlength=self.minlength).astype(np.float32) + + +class TestCase2(TestBincountOp): + # with weights(other) + def setUp(self): + self.op_type = "bincount" + self.init_test_case() + self.inputs = {"X": self.np_input, "Weights": self.np_weights} + self.attrs = {"minlength": self.minlength} + self.outputs = {"Out": self.Out} + + def init_test_case(self): + self.minlength = 0 + self.np_weights = np.random.randint(low=0, high=20, size=10) + self.np_input = np.random.randint(low=0, high=20, size=10) + self.Out = np.bincount( + self.np_input, weights=self.np_weights, minlength=self.minlength) + + +class TestCase3(TestBincountOp): + # empty input + def init_test_case(self): + self.minlength = 0 + self.np_input = np.array([], dtype=np.int64) + self.Out = np.bincount(self.np_input, minlength=self.minlength) + + +class TestCase4(TestBincountOp): + # with input(INT32) + def init_test_case(self): + self.minlength = 0 + self.np_input = np.random.randint( + low=0, high=20, size=10).astype(np.int32) + self.Out = np.bincount(self.np_input, minlength=self.minlength) + + +class TestCase5(TestBincountOp): + # with minlength greater than max(X) + def init_test_case(self): + self.minlength = 20 + self.np_input = np.random.randint(low=0, high=10, size=10) + self.Out = np.bincount(self.np_input, minlength=self.minlength) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b898b60fe47126b6d157b0f4fb35806a35b7844d..f528714e9164a40943e6ecefa83938059be3b992 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -44,6 +44,7 @@ from .linalg import cross # noqa: F401 from .linalg import cholesky # noqa: F401 from .linalg import bmm # noqa: F401 from .linalg import histogram # noqa: F401 +from .linalg import bincount # noqa: F401 from .linalg import mv # noqa: F401 from .linalg import eig # noqa: F401 from .linalg import matrix_power # noqa: F401 @@ -236,6 +237,7 @@ tensor_method_func = [ #noqa 'cholesky', 'bmm', 'histogram', + 'bincount', 'mv', 'matrix_power', 'qr', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 6853d904adbf6e2b8227549332e72eec5a4e31cb..aea56432fa9cab1bbc888670060cc6a501132467 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1293,6 +1293,59 @@ def histogram(input, bins=100, min=0, max=0, name=None): return out +def bincount(x, weights=None, minlength=0, name=None): + """ + Computes frequency of each value in the input tensor. + + Args: + x (Tensor): A Tensor with non-negative integer. Should be 1-D tensor. + weights (Tensor, optional): Weight for each value in the input tensor. Should have the same shape as input. Default is None. + minlength (int, optional): Minimum number of bins. Should be non-negative integer. Default is 0. + name(str, optional): The default value is None. Normally there is no need for user to set this + property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: The tensor of frequency. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.to_tensor([1, 2, 1, 4, 5]) + result1 = paddle.bincount(x) + print(result1) # [0, 2, 1, 0, 1, 1] + + w = paddle.to_tensor([2.1, 0.4, 0.1, 0.5, 0.5]) + result2 = paddle.bincount(x, weights=w) + print(result2) # [0., 2.19999981, 0.40000001, 0., 0.50000000, 0.50000000] + """ + if x.dtype not in [paddle.int32, paddle.int64]: + raise TypeError("Elements in Input(x) should all be integers") + + if in_dygraph_mode(): + return _C_ops.bincount(x, weights, "minlength", minlength) + + helper = LayerHelper('bincount', **locals()) + + check_variable_and_dtype(x, 'X', ['int32', 'int64'], 'bincount') + + if weights is not None: + check_variable_and_dtype(weights, 'Weights', + ['int32', 'int64', 'float32', 'float64'], + 'bincount') + out = helper.create_variable_for_type_inference(dtype=weights.dtype) + else: + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='bincount', + inputs={'X': x, + 'Weights': weights}, + outputs={'Out': out}, + attrs={'minlength': minlength}) + return out + + def mv(x, vec, name=None): """ Performs a matrix-vector product of the matrix x and the vector vec.