From 2fd728a97860b3126bc3f7302269d9dd21736fdc Mon Sep 17 00:00:00 2001 From: liuwei1031 <46661762+liuwei1031@users.noreply.github.com> Date: Sun, 12 Apr 2020 09:46:40 +0800 Subject: [PATCH] add new dot op(#23418) --- paddle/fluid/framework/ddim.cc | 10 ++ paddle/fluid/framework/ddim.h | 3 + paddle/fluid/operators/dot_op.cc | 160 +++++++++++++++++ paddle/fluid/operators/dot_op.cu | 28 +++ paddle/fluid/operators/dot_op.h | 168 ++++++++++++++++++ python/paddle/__init__.py | 2 +- .../fluid/tests/unittests/test_dot_op.py | 105 +++++++++++ .../white_list/no_grad_set_white_list.py | 1 + python/paddle/tensor/__init__.py | 3 +- python/paddle/tensor/linalg.py | 52 +++++- 10 files changed, 527 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/operators/dot_op.cc create mode 100644 paddle/fluid/operators/dot_op.cu create mode 100644 paddle/fluid/operators/dot_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_dot_op.py diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index b3aaa01d53b..1dae5e12a8c 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -48,6 +48,16 @@ bool DDim::operator==(const DDim& d) const { bool DDim::operator!=(const DDim& d) const { return !(*this == d); } +std::string DDim::to_str() const { + std::stringstream ss; + ss << '['; + if (rank_ > 0) ss << dim_[0]; + + for (int i = 1; i < rank_; ++i) ss << ", " << dim_[i]; + ss << ']'; + return ss.str(); +} + struct ProductVisitor { template inline int64_t operator()(const Dim& dim) { diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index 14824afbea7..2f04c428e44 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include #include "paddle/fluid/framework/dim.h" @@ -123,6 +124,8 @@ class DDim { inline int size() const { return rank_; } + std::string to_str() const; + private: template inline Dim& UnsafeCast() { diff --git a/paddle/fluid/operators/dot_op.cc b/paddle/fluid/operators/dot_op.cc new file mode 100644 index 00000000000..0527445adf0 --- /dev/null +++ b/paddle/fluid/operators/dot_op.cc @@ -0,0 +1,160 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/dot_op.h" + +namespace paddle { +namespace operators { + +class DotOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ(true, ctx->HasInput("X"), + platform::errors::PreconditionNotMet( + "Input(X) of DotOp should not be null.")); + PADDLE_ENFORCE_EQ(true, ctx->HasInput("Y"), + platform::errors::PreconditionNotMet( + "Input(Y) of DotOp should not be null.")); + PADDLE_ENFORCE_EQ(true, ctx->HasOutput("Out"), + platform::errors::PreconditionNotMet( + "Output(Out) of DotOp should not be null.")); + + auto x_dims = ctx->GetInputDim("X"); + auto x_rank = (size_t)x_dims.size(); + PADDLE_ENFORCE_EQ(true, 1 == x_rank || 2 == x_rank, + platform::errors::PreconditionNotMet( + "ShapeError: The dimensions of input tensor X (%s) " + "should be 1 or 2", + x_dims.to_str())); + + auto y_dims = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_EQ( + true, x_rank == (size_t)y_dims.size(), + platform::errors::PreconditionNotMet( + "ShapeError: The shape of input tensor Y: %s should match with " + "input tenosr X: %s", + y_dims.to_str(), x_dims.to_str())); + bool shape_match = true; + for (size_t i = 0; i < x_rank; ++i) { + if (x_dims[i] != y_dims[i]) { + shape_match = false; + break; + } + } + + PADDLE_ENFORCE_EQ(true, shape_match, + platform::errors::PreconditionNotMet( + "ShapeError: The shape of input tensor X: %s should " + "be exactly the same " + "with input tensor Y: %s", + x_dims.to_str(), y_dims.to_str())); + auto dims = vectorize(x_dims); + dims[dims.size() - 1] = 1; + ctx->SetOutputDim("Out", framework::make_ddim(dims)); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class DotOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() final { + AddInput("X", "(Tensor) The first input tensor. "); + AddInput("Y", "(Tensor) The second input tensor. "); + AddOutput("Out", "(Tensor) The result tensor."); + AddComment(""); + } +}; + +class DotGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + true, ctx->HasInput("X"), + platform::errors::PreconditionNotMet("Input(X) should not be null.")); + PADDLE_ENFORCE_EQ( + true, ctx->HasInput("Y"), + platform::errors::PreconditionNotMet("Input(Y) should not be null.")); + PADDLE_ENFORCE_EQ(true, ctx->HasInput(framework::GradVarName("Out")), + platform::errors::PreconditionNotMet( + "Input(Out@GRAD) should not be null.")); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { + ctx->ShareDim("X", /*->*/ x_grad_name); + ctx->ShareLoD("X", /*->*/ x_grad_name); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->ShareDim("Y", /*->*/ y_grad_name); + ctx->ShareLoD("Y", /*->*/ y_grad_name); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")), + ctx.GetPlace()); + } +}; + +template +class DotOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("dot_grad"); + + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetAttrMap(this->Attrs()); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(dot, ops::DotOp, ops::DotOpMaker, + ops::DotOpGradMaker, + ops::DotOpGradMaker); + +REGISTER_OPERATOR(dot_grad, ops::DotGradOp); + +REGISTER_OP_CPU_KERNEL( + dot, ops::DotKernel, + ops::DotKernel, + ops::DotKernel, + ops::DotKernel); +REGISTER_OP_CPU_KERNEL( + dot_grad, ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel); diff --git a/paddle/fluid/operators/dot_op.cu b/paddle/fluid/operators/dot_op.cu new file mode 100644 index 00000000000..eb7ebbe32d7 --- /dev/null +++ b/paddle/fluid/operators/dot_op.cu @@ -0,0 +1,28 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/dot_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL(dot, ops::DotKernel, + ops::DotKernel, + ops::DotKernel, + ops::DotKernel); +REGISTER_OP_CUDA_KERNEL(dot_grad, + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel, + ops::DotGradKernel); diff --git a/paddle/fluid/operators/dot_op.h b/paddle/fluid/operators/dot_op.h new file mode 100644 index 00000000000..2580b00d7c2 --- /dev/null +++ b/paddle/fluid/operators/dot_op.h @@ -0,0 +1,168 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenMatrix = framework::EigenMatrix; + +template +class DotKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* tensor_x = ctx.Input("X"); + auto* tensor_y = ctx.Input("Y"); + auto* tensor_out = ctx.Output("Out"); + tensor_out->mutable_data(ctx.GetPlace()); + +#ifdef __NVCC__ + if (1 == tensor_out->dims().size()) { + auto out = framework::EigenScalar::From(*tensor_out); + auto x = framework::EigenVector::Flatten(*tensor_x); + auto y = framework::EigenVector::Flatten(*tensor_y); + + auto& dev = *ctx.template device_context().eigen_device(); + out.device(dev) = (x * y).sum(); + } else { + auto out = EigenMatrix::From(*tensor_out); + auto x = EigenMatrix::From(*tensor_x); + auto y = EigenMatrix::From(*tensor_y); + + auto& dev = *ctx.template device_context().eigen_device(); + out.device(dev) = (x * y).sum(Eigen::DSizes(1)); + } +#else + const auto* data_x = tensor_x->data(); + const auto* data_y = tensor_y->data(); + auto* data_out = tensor_out->data(); + + auto x_dims = tensor_x->dims(); + auto step = x_dims[x_dims.size() - 1]; + int size = static_cast(framework::product(x_dims)); + + for (int ind = -1, j = 0; j < size; ++j) { + if (j % step == 0) { + ++ind; + data_out[ind] = data_x[j] * data_y[j]; + } else { + data_out[ind] += data_x[j] * data_y[j]; + } + } +#endif + } +}; + +template +class DotGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* tensor_x = ctx.Input("X"); + auto* tensor_y = ctx.Input("Y"); + auto* tensor_dout = ctx.Input(framework::GradVarName("Out")); + auto* tensor_dx = ctx.Output(framework::GradVarName("X")); + auto* tensor_dy = ctx.Output(framework::GradVarName("Y")); + + if (tensor_dx) tensor_dx->mutable_data(ctx.GetPlace()); + if (tensor_dy) tensor_dy->mutable_data(ctx.GetPlace()); +#ifdef __NVCC__ + if (1 == tensor_dout->dims().size()) { + auto dout = framework::EigenVector::Flatten(*tensor_dout); + + if (tensor_dx) { + auto y = framework::EigenVector::Flatten(*tensor_y); + auto dx = framework::EigenVector::Flatten(*tensor_dx); + auto& dev = + *ctx.template device_context().eigen_device(); + Eigen::DSizes size(tensor_dx->numel()); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + auto x = framework::EigenVector::Flatten(*tensor_x); + auto dy = framework::EigenVector::Flatten(*tensor_dy); + auto& dev = + *ctx.template device_context().eigen_device(); + Eigen::DSizes size(tensor_dy->numel()); + dy.device(dev) = x * dout.broadcast(size); + } + } else { + auto dout = EigenMatrix::From(*tensor_dout); + + if (tensor_dx) { + tensor_dx->mutable_data(ctx.GetPlace()); + auto y = EigenMatrix::From(*tensor_y); + auto dx = EigenMatrix::From(*tensor_dx); + auto& dev = + *ctx.template device_context().eigen_device(); + Eigen::DSizes size(1, tensor_dx->dims()[1]); + dx.device(dev) = y * dout.broadcast(size); + } + + if (tensor_dy) { + tensor_dy->mutable_data(ctx.GetPlace()); + auto x = EigenMatrix::From(*tensor_x); + auto dy = EigenMatrix::From(*tensor_dy); + auto& dev = + *ctx.template device_context().eigen_device(); + Eigen::DSizes size(1, tensor_dy->dims()[1]); + dy.device(dev) = x * dout.broadcast(size); + } + } +#else + const auto* data_dout = tensor_dout->data(); + + if (tensor_dx) { + auto* data_dx = tensor_dx->mutable_data(ctx.GetPlace()); + const auto* data_y = tensor_y->data(); + const framework::DDim& dim = tensor_x->dims(); + size_t N = static_cast(framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dx[i] = data_y[i] * data_dout[s]; + } + } + + if (tensor_dy) { + auto* data_dy = tensor_dy->mutable_data(ctx.GetPlace()); + const auto* data_x = tensor_x->data(); + const framework::DDim& dim = tensor_y->dims(); + size_t N = static_cast(framework::product(dim)); + + auto step = dim[dim.size() - 1]; + + int s = -1; + for (size_t i = 0; i < N; ++i) { + if (0 == i % step) ++s; + data_dy[i] = data_x[i] * data_dout[s]; + } + } +#endif + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 32c4c68168d..5e28e698be2 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -148,7 +148,7 @@ from .tensor.math import addmm #DEFINE_ALIAS # from .tensor.io import save #DEFINE_ALIAS # from .tensor.io import load #DEFINE_ALIAS from .tensor.linalg import matmul #DEFINE_ALIAS -# from .tensor.linalg import dot #DEFINE_ALIAS +from .tensor.linalg import dot #DEFINE_ALIAS # from .tensor.linalg import einsum #DEFINE_ALIAS # from .tensor.linalg import morm #DEFINE_ALIAS # from .tensor.linalg import transpose #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_dot_op.py b/python/paddle/fluid/tests/unittests/test_dot_op.py new file mode 100644 index 00000000000..d95f818a62b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dot_op.py @@ -0,0 +1,105 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import paddle +import paddle.fluid as fluid +import unittest +import numpy as np +from op_test import OpTest, skip_check_grad_ci +from paddle.fluid.op import Operator +from paddle.fluid import compiler, Program, program_guard + + +class DotOp(OpTest): + def setUp(self): + self.op_type = "dot" + self.init_dtype() + self.init_input_output() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.outputs = {'Out': self.out} + self.attrs = {} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + self.check_grad(['Y'], 'Out', no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + self.check_grad(['X'], 'Out', no_grad_set=set('Y')) + + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [121]).astype(self.dtype) + self.y = np.random.uniform(1, 3, [121]).astype(self.dtype) + self.out = np.dot(self.x, self.y) + + def init_dtype(self): + self.dtype = np.float64 + + +class DotOpBatch(DotOp): + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [132]).astype(self.dtype).reshape( + [11, 12]) + self.y = np.random.uniform(1, 3, [132]).astype(self.dtype).reshape( + [11, 12]) + self.out = np.sum(self.x * self.y, axis=1).reshape([11, 1]) + + +class TestDotOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + # the input dtype of elementwise_mul must be float16 or float32 or float64 or int32 or int64 + # float16 only can be set on GPU place + x1 = fluid.layers.data(name='x1', shape=[120], dtype="uint8") + y1 = fluid.layers.data(name='y1', shape=[120], dtype="uint8") + self.assertRaises(Exception, paddle.dot, x1, y1) + + x2 = fluid.layers.data(name='x2', shape=[2, 3], dtype="float32") + y2 = fluid.layers.data(name='y2', shape=[2, 3], dtype="float32") + self.assertRaises(Exception, paddle.dot, x2, y2) + + x3 = fluid.layers.data(name='x3', shape=[3], dtype="float32") + y3 = fluid.layers.data(name='y3', shape=[2, 3], dtype="float32") + self.assertRaises(Exception, paddle.dot, x2, y3) + + +class TestDygraph(unittest.TestCase): + def test_dygraph(self): + with fluid.dygraph.guard(): + x1 = fluid.dygraph.to_variable(np.array([1, 3]).astype(np.float32)) + y1 = fluid.dygraph.to_variable(np.array([2, 5]).astype(np.float32)) + self.assertTrue( + np.allclose(paddle.dot(x1, y1).numpy(), np.array([17]))) + + x1 = fluid.dygraph.to_variable( + np.array([[1, 3], [3, 5]]).astype(np.float32)) + y1 = fluid.dygraph.to_variable( + np.array([[2, 5], [6, 8]]).astype(np.float32)) + self.assertTrue( + np.array_equal( + paddle.dot(x1, y1).numpy(), np.array([[17], [58]]))) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py index 3f6978d48aa..eb1471e377a 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py @@ -34,6 +34,7 @@ NEED_TO_FIX_OP_LIST = [ 'deformable_conv_v1', 'depthwise_conv2d', 'depthwise_conv2d_transpose', + 'dot', 'elementwise_add', 'elementwise_div', 'elementwise_max', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 40476b49a37..4748172a00c 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -123,7 +123,7 @@ from .math import addmm #DEFINE_ALIAS # from .io import save #DEFINE_ALIAS # from .io import load #DEFINE_ALIAS from .linalg import matmul #DEFINE_ALIAS -# from .linalg import dot #DEFINE_ALIAS +from .linalg import dot #DEFINE_ALIAS # from .linalg import einsum #DEFINE_ALIAS # from .linalg import morm #DEFINE_ALIAS # from .linalg import transpose #DEFINE_ALIAS @@ -131,7 +131,6 @@ from .linalg import dist #DEFINE_ALIAS # from .linalg import t #DEFINE_ALIAS # from .linalg import cross #DEFINE_ALIAS # from .linalg import cholesky #DEFINE_ALIAS -# from .linalg import dot #DEFINE_ALIAS # from .manipulation import cast #DEFINE_ALIAS # from .manipulation import concat #DEFINE_ALIAS # from .manipulation import expand #DEFINE_ALIAS diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 7baba355180..4d0d99edf42 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -16,10 +16,9 @@ from ..fluid.layer_helper import LayerHelper from ..fluid.data_feeder import check_variable_and_dtype, check_type from ..fluid.framework import in_dygraph_mode -# TODO: define functions of linear algebra __all__ = [ 'matmul', - # 'dot', + 'dot', # 'einsum', # 'morm', # 'transpose', @@ -234,3 +233,52 @@ def dist(x, y, p=2): helper.append_op( type='dist', inputs=inputs, outputs={'Out': out}, attrs=attrs) return out + + +def dot(x, y, name=None): + """ + This operator calculates inner product for vectors. + + .. note:: + Only support 1-d Tensor(vector). + + Parameters: + + x(Variable): 1-D ``Tensor`` or ``LoDTensor``. Its datatype should be ``float32``, ``float64``, ``int32``, ``int64`` + y(Variable): 1-D ``Tensor`` or ``LoDTensor``. Its datatype soulde be ``float32``, ``float64``, ``int32``, ``int64`` + name(str, optional): Name of the output. Default is None. It's used to print debug info for developers. Details: :ref:`api_guide_Name` + + Examples: + + .. code-block:: python + + import paddle + import paddle.fluid as fluid + import numpy as np + + with fluid.dygraph.guard(): + x = fluid.dygraph.to_variable(np.random.uniform(0.1, 1, [10]).astype(np.float32)) + y = fluid.dygraph.to_variable(np.random.uniform(1, 3, [10]).astype(np.float32)) + z = paddle.dot(x, y) + print(z.numpy()) + + """ + op_type = 'dot' + assert x is not None, 'x cannot be None in {}'.format(op_type) + assert y is not None, 'y cannot be None in {}'.format(op_type) + + check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'], + op_type) + check_variable_and_dtype(y, 'y', ['float32', 'float64', 'int32', 'int64'], + op_type) + + helper = LayerHelper(op_type, **locals()) + if name is None: + out = helper.create_variable_for_type_inference(dtype=x.dtype) + else: + out = helper.create_variable( + name=name, dtype=x.dtype, persistable=False) + helper.append_op( + type="dot", inputs={'X': x, + 'Y': y}, attrs={}, outputs={"Out": out}) + return out -- GitLab