未验证 提交 2785f876 编写于 作者: Z zyfncg 提交者: GitHub

add embedding yaml (#43029)

* add embedding yaml

* fix infermeta bug

* fix bug of selected_rows infer_meta

* fix selected_rows

* add unittest
上级 b779d2b8
......@@ -638,6 +638,80 @@ Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
return out;
}
Tensor embedding_impl(const Tensor& x,
const Tensor& weight,
int64_t padding_idx,
bool sparse) {
DataType kernel_data_type = ParseDataType(weight);
auto kernel_key_set = ParseKernelKeyByInputArgs(weight);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
VLOG(6) << "embedding API kernel key: [" << kernel_key.backend() << ", "
<< kernel_key.layout() << ", " << kernel_data_type << "]";
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
Tensor api_output;
if (phi::DenseTensor::classof(weight.impl().get())) {
const auto& kernel =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
"embedding",
{kernel_key.backend(), kernel_key.layout(), kernel_data_type});
VLOG(6) << "embedding API kernel: " << kernel;
auto input_x = PrepareData(x, kernel.InputAt(0), {});
auto input_weight = PrepareData(weight, kernel.InputAt(1), {});
auto* kernel_out = SetKernelOutput(kernel_key.backend(), &api_output);
phi::MetaTensor meta_out(kernel_out);
phi::EmbeddingInferMeta(MakeMetaTensor(*input_x),
MakeMetaTensor(*input_weight),
padding_idx,
sparse,
&meta_out);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
int64_t,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{
(*kernel_fn)(*dev_ctx, *input_x, *input_weight, padding_idx, kernel_out);
}
} else {
const auto& kernel =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
"sparse_weight_embedding",
{kernel_key.backend(), kernel_key.layout(), kernel_data_type});
VLOG(6) << "sparse_weight_embedding API kernel: " << kernel;
auto input_x = PrepareData(x, kernel.InputAt(0), {});
auto input_weight = TensorToSelectedRows(weight);
auto* kernel_out = SetKernelOutput(kernel_key.backend(), &api_output);
phi::MetaTensor meta_out(kernel_out);
phi::EmbeddingInferMeta(MakeMetaTensor(*input_x),
MakeMetaTensor(*input_weight),
padding_idx,
sparse,
&meta_out);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::SelectedRows&,
int64_t,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{
(*kernel_fn)(*dev_ctx, *input_x, *input_weight, padding_idx, kernel_out);
}
}
return api_output;
}
std::vector<Tensor> split_impl(const Tensor& x,
const IntArray& num_or_sections,
const Scalar& axis) {
......@@ -1176,6 +1250,125 @@ void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad) {
(*kernel_fn)(*dev_ctx, *dense_out_grad, kernel_out);
}
void embedding_grad_impl(const Tensor& x,
const Tensor& weight,
const Tensor& out_grad,
int64_t padding_idx,
bool sparse,
Tensor* weight_grad) {
DataType kernel_data_type = ParseDataType(weight);
auto kernel_key_set = ParseKernelKeyByInputArgs(weight);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
VLOG(6) << "embedding_grad API kernel key: [" << kernel_key.backend() << ", "
<< kernel_key.layout() << ", " << kernel_data_type << "]";
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
if (phi::DenseTensor::classof(weight.impl().get())) {
std::string kernel_name =
sparse ? "embedding_sparse_grad" : "embedding_grad";
const auto& kernel =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name,
{kernel_key.backend(), kernel_key.layout(), kernel_data_type});
VLOG(6) << kernel_name << " API kernel: " << kernel;
auto input_x = PrepareData(x, kernel.InputAt(0), {});
auto input_weight = PrepareData(weight, kernel.InputAt(1), {});
auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
if (sparse) {
auto* kernel_out =
SetSelectedRowsKernelOutput(kernel_key.backend(), weight_grad);
phi::MetaTensor meta_out(kernel_out);
meta_out.set_dims(input_weight->dims());
meta_out.set_dtype(input_weight->dtype());
kernel_out->set_height(input_weight->dims()[0]);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
int64_t,
phi::SelectedRows*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*input_x,
*input_weight,
*input_out_grad,
padding_idx,
kernel_out);
} else {
auto* kernel_out = SetKernelOutput(kernel_key.backend(), weight_grad);
phi::MetaTensor meta_out(kernel_out);
phi::UnchangedInferMeta(MakeMetaTensor(*input_weight), &meta_out);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
int64_t,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*input_x,
*input_weight,
*input_out_grad,
padding_idx,
kernel_out);
}
} else {
std::string kernel_name = sparse ? "sparse_weight_embedding_sparse_grad"
: "sparse_weight_embedding_grad";
const auto& kernel =
phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name,
{kernel_key.backend(), kernel_key.layout(), kernel_data_type});
VLOG(6) << kernel_name << " API kernel: " << kernel;
auto input_x = PrepareData(x, kernel.InputAt(0), {});
auto input_weight = TensorToSelectedRows(weight);
auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {});
if (sparse) {
auto* kernel_out =
SetSelectedRowsKernelOutput(kernel_key.backend(), weight_grad);
phi::MetaTensor meta_out(kernel_out);
phi::UnchangedInferMeta(MakeMetaTensor(*input_weight), &meta_out);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::SelectedRows&,
const phi::DenseTensor&,
int64_t,
phi::SelectedRows*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*input_x,
*input_weight,
*input_out_grad,
padding_idx,
kernel_out);
} else {
auto* kernel_out = SetKernelOutput(kernel_key.backend(), weight_grad);
phi::MetaTensor meta_out(kernel_out);
meta_out.set_dims(input_weight->GetCompleteDims());
meta_out.set_dtype(input_weight->dtype());
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::SelectedRows&,
const phi::DenseTensor&,
int64_t,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*input_x,
*input_weight,
*input_out_grad,
padding_idx,
kernel_out);
}
}
}
void real_grad_impl(const Tensor& out_grad, Tensor* x_grad) {
phi::KernelKey kernel_key{ParseBackend(out_grad),
out_grad.layout(),
......
......@@ -98,6 +98,11 @@ Tensor conv2d_impl(const Tensor& input,
Tensor copy_to_impl(const Tensor& x, Place place, bool blocking);
Tensor embedding_impl(const Tensor& x,
const Tensor& weight,
int64_t padding_idx,
bool sparse);
std::vector<Tensor> split_impl(const Tensor& x,
const IntArray& num_or_sections,
const Scalar& axis);
......@@ -145,6 +150,13 @@ void conv2d_grad_impl(const Tensor& input,
void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad);
void embedding_grad_impl(const Tensor& x,
const Tensor& weight,
const Tensor& out_grad,
int64_t padding_idx,
bool sparse,
Tensor* weight_grad);
void real_grad_impl(const Tensor& out_grad, Tensor* x_grad);
} // namespace experimental
......
......@@ -983,6 +983,32 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
out->share_lod(x);
}
void EmbeddingInferMeta(const MetaTensor& x,
const MetaTensor& weight,
int64_t padding_idx,
bool sparse,
MetaTensor* out) {
const auto& table_dims = weight.dims();
const auto& ids_dims = x.dims();
int ids_rank = ids_dims.size();
VLOG(5) << "ids rank is " << ids_rank << std::endl;
PADDLE_ENFORCE_EQ(
table_dims.size(),
2,
phi::errors::InvalidArgument(
"ShapeError: The dimensions of the 'lookup table' must be 2. "
"But received lookup table's dimensions = %d, "
"lookup table's shape = [%s].",
table_dims.size(),
table_dims));
auto output_dims = phi::vectorize(ids_dims);
output_dims.push_back(table_dims[1]);
out->set_dims(phi::make_ddim(output_dims));
out->set_dtype(weight.dtype());
out->share_lod(x);
}
void ExpandAsInferMeta(const MetaTensor& x,
const MetaTensor& y,
const std::vector<int>& target_shape,
......
......@@ -154,6 +154,12 @@ void ElementwiseRawInferMeta(const MetaTensor& x_meta,
int axis,
MetaTensor* out);
void EmbeddingInferMeta(const MetaTensor& x,
const MetaTensor& weight,
int64_t padding_idx,
bool sparse,
MetaTensor* out);
void ExpandAsInferMeta(const MetaTensor& x,
const MetaTensor& y,
const std::vector<int>& target_shape,
......
......@@ -15,6 +15,7 @@ cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_empty_api SRCS test_empty_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_fill_api SRCS test_fill_api.cc DEPS ${COMMON_API_TEST_DEPS} api_scalar)
cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_embedding_api SRCS test_embedding_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_cast_api SRCS test_cast_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS ${COMMON_API_TEST_DEPS})
cc_test(test_to_api SRCS test_to_api.cc DEPS ${COMMON_API_TEST_DEPS})
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/phi/api/backward/backward_api.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
PD_DECLARE_KERNEL(sparse_weight_embedding, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sparse_weight_embedding_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sparse_weight_embedding_sparse_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(empty, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
namespace paddle {
namespace tests {
TEST(API, sparse_weight_embedding) {
auto x = paddle::experimental::empty({4}, DataType::INT32);
auto* x_data = x.data<int32_t>();
x_data[0] = 0;
x_data[1] = 4;
x_data[2] = 3;
x_data[3] = 1;
auto weight_sr = std::make_shared<phi::SelectedRows>(
std::vector<int64_t>{0, 1, 2, 3, 4, 5, 6}, 16);
*weight_sr->mutable_value() = *static_cast<phi::DenseTensor*>(
paddle::experimental::full({7, 3}, 2, DataType::FLOAT32).impl().get());
paddle::experimental::Tensor weight;
weight.set_impl(weight_sr);
auto out = paddle::experimental::embedding(x, weight);
// 3. check result
ASSERT_EQ(out.dims().size(), 2);
ASSERT_EQ(out.dims()[0], 4);
ASSERT_EQ(out.numel(), 12);
ASSERT_EQ(out.type(), phi::DataType::FLOAT32);
ASSERT_EQ(out.layout(), phi::DataLayout::NCHW);
}
TEST(API, sparse_weight_embedding_grad) {
auto x = paddle::experimental::empty({4}, DataType::INT32);
auto* x_data = x.data<int32_t>();
x_data[0] = 0;
x_data[1] = 4;
x_data[2] = 3;
x_data[3] = 1;
auto weight_sr = std::make_shared<phi::SelectedRows>(
std::vector<int64_t>{0, 1, 2, 3, 4, 5, 6}, 16);
*weight_sr->mutable_value() = *static_cast<phi::DenseTensor*>(
paddle::experimental::full({7, 3}, 2, DataType::FLOAT32).impl().get());
paddle::experimental::Tensor weight;
weight.set_impl(weight_sr);
auto out_grad = paddle::experimental::full({4, 3}, 1, DataType::FLOAT32);
paddle::experimental::Tensor weight_grad;
paddle::experimental::embedding_grad(
x, weight, out_grad, -1, false, &weight_grad);
// 3. check result
ASSERT_EQ(weight_grad.dims().size(), 2);
ASSERT_EQ(weight_grad.dims()[0], 16);
ASSERT_EQ(weight_grad.numel(), 48);
ASSERT_EQ(weight_grad.type(), phi::DataType::FLOAT32);
ASSERT_EQ(weight_grad.layout(), phi::DataLayout::NCHW);
}
TEST(API, sparse_weight_embedding_sparse_grad) {
auto x = paddle::experimental::empty({4}, DataType::INT32);
auto* x_data = x.data<int32_t>();
x_data[0] = 0;
x_data[1] = 4;
x_data[2] = 3;
x_data[3] = 1;
auto weight_sr = std::make_shared<phi::SelectedRows>(
std::vector<int64_t>{0, 1, 2, 3, 4, 5, 6}, 16);
*weight_sr->mutable_value() = *static_cast<phi::DenseTensor*>(
paddle::experimental::full({7, 3}, 2, DataType::FLOAT32).impl().get());
paddle::experimental::Tensor weight;
weight.set_impl(weight_sr);
auto out_grad = paddle::experimental::full({4, 3}, 1, DataType::FLOAT32);
paddle::experimental::Tensor weight_grad;
paddle::experimental::embedding_grad(
x, weight, out_grad, -1, true, &weight_grad);
// 3. check result
ASSERT_EQ(weight_grad.dims().size(), 2);
ASSERT_EQ(weight_grad.dims()[0], 4);
ASSERT_EQ(weight_grad.numel(), 12);
ASSERT_EQ(weight_grad.type(), phi::DataType::FLOAT32);
ASSERT_EQ(weight_grad.layout(), phi::DataLayout::NCHW);
}
} // namespace tests
} // namespace paddle
......@@ -48,6 +48,7 @@ class TestStaticGraphSupportMultipleInt(unittest.TestCase):
class TestLookupTableOp(OpTest):
def setUp(self):
self.op_type = "lookup_table_v2"
self.python_api = paddle.nn.functional.embedding
table = np.random.random((17, 31)).astype("float64")
ids = np.random.randint(0, 17, 4).astype(self.id_dtype())
self.inputs = {'W': table, 'Ids': ids}
......@@ -57,10 +58,10 @@ class TestLookupTableOp(OpTest):
return "int64"
def test_check_output(self):
self.check_output()
self.check_output(check_eager=True)
def test_check_grad(self):
self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
self.check_grad(['W'], 'Out', no_grad_set=set('Ids'), check_eager=True)
class TestLookupTableOpInt16(OpTest):
......
......@@ -200,7 +200,9 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
raise ValueError("padding_idx must be within [-{}, {})".format(
weight.shape[0], weight.shape[0]))
if in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_embedding(x, weight, padding_idx, sparse)
elif _in_legacy_dygraph():
return _C_ops.lookup_table_v2(
weight, x, 'is_sparse', sparse, 'is_distributed', False,
'remote_prefetch', False, 'padding_idx', padding_idx)
......
......@@ -613,6 +613,12 @@
func : elu
backward : elu_grad
- api : embedding
args : (Tensor x, Tensor weight, int64_t padding_idx=-1, bool sparse=false)
output : Tensor
invoke : embedding_impl(x, weight, padding_idx, sparse)
backward : embedding_grad
- api : empty
args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor
......
......@@ -609,6 +609,12 @@
backward : elu_double_grad
inplace : (out_grad -> x_grad)
- backward_api : embedding_grad
forward : embedding (Tensor x, Tensor weight, int64_t padding_idx=-1, bool sparse=false) -> Tensor(out)
args : (Tensor x, Tensor weight, Tensor out_grad, int64_t padding_idx=-1, bool sparse=false)
output : Tensor(weight_grad)
invoke : embedding_grad_impl(x, weight, out_grad, padding_idx, sparse, weight_grad)
- backward_api : erf_grad
forward : erf (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册