提交 e81773c9 编写于 作者: P phlrain

move reset impl to phi; test=develop

上级 ec0e8391
...@@ -114,12 +114,95 @@ void EmbeddingGradKernel(const Context& ctx, ...@@ -114,12 +114,95 @@ void EmbeddingGradKernel(const Context& ctx,
paddle::framework::TransToProtoVarType(input.dtype()), functor); paddle::framework::TransToProtoVarType(input.dtype()), functor);
} }
template <typename T, typename Context>
struct LookupTableV2SparseGradCPUFunctor {
LookupTableV2SparseGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
weight_grad_(weight_grad),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
DDim table_dim = weight_.dims();
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_num = static_cast<int64_t>(ids.size());
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
auto* d_table = weight_grad_;
auto* d_output = &out_grad_;
d_table->set_rows(ids);
auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]});
d_table_value->template mutable_data<T>(dev_ctx_.GetPlace());
d_table->set_height(table_dim[0]);
auto* d_output_data = d_output->template data<T>();
auto* d_table_data = d_table_value->template data<T>();
auto d_output_dims = d_output->dims();
auto d_output_dims_2d =
flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
PADDLE_ENFORCE_EQ(d_table_value->dims(),
d_output_dims_2d,
phi::errors::InvalidArgument(
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s].",
d_table_value->dims(),
d_output_dims_2d));
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
}
private:
const Context& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
const DenseTensor& out_grad_;
SelectedRows* weight_grad_;
int64_t padding_idx_;
};
template <typename T, typename Context>
void EmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad) {
LookupTableV2SparseGradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
}
} // namespace phi } // namespace phi
PT_REGISTER_KERNEL(embedding_grad, PD_REGISTER_KERNEL(embedding_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::EmbeddingGradKernel, phi::EmbeddingGradKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {}
PD_REGISTER_KERNEL(embedding_sparse_grad,
CPU,
ALL_LAYOUT,
phi::EmbeddingSparseGradKernel,
float,
double,
phi::dtype::float16) {}
...@@ -99,7 +99,7 @@ void EmbeddingKernel(const Context& ctx, ...@@ -99,7 +99,7 @@ void EmbeddingKernel(const Context& ctx,
} // namespace phi } // namespace phi
PT_REGISTER_KERNEL(embedding, PD_REGISTER_KERNEL(embedding,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::EmbeddingKernel, phi::EmbeddingKernel,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/embedding_grad_kernel.h" #include "paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h" #include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
...@@ -23,13 +23,13 @@ ...@@ -23,13 +23,13 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
struct LookupTableV2GradCPUFunctor { struct SparseWeightLookupTableV2GradCPUFunctor {
LookupTableV2GradCPUFunctor(const Context& dev_ctx, SparseWeightLookupTableV2GradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input, const DenseTensor& input,
const SelectedRows& weight, const SelectedRows& weight,
const DenseTensor& out_grad, const DenseTensor& out_grad,
int64_t padding_idx, int64_t padding_idx,
DenseTensor* weight_grad) DenseTensor* weight_grad)
: dev_ctx_(dev_ctx), : dev_ctx_(dev_ctx),
input_(input), input_(input),
weight_(weight), weight_(weight),
...@@ -101,6 +101,68 @@ struct LookupTableV2GradCPUFunctor { ...@@ -101,6 +101,68 @@ struct LookupTableV2GradCPUFunctor {
int64_t padding_idx_; int64_t padding_idx_;
}; };
template <typename T, typename Context>
struct SparseWeightLookupTableV2SparseGradCPUFunctor {
SparseWeightLookupTableV2SparseGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
weight_grad_(weight_grad),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
DDim table_dim = weight_.dims();
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_num = static_cast<int64_t>(ids.size());
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
auto* d_table = weight_grad_;
auto* d_output = &out_grad_;
d_table->set_rows(ids);
auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]});
d_table_value->template mutable_data<T>(dev_ctx_.GetPlace());
d_table->set_height(table_dim[0]);
auto* d_output_data = d_output->template data<T>();
auto* d_table_data = d_table_value->template data<T>();
auto d_output_dims = d_output->dims();
auto d_output_dims_2d =
phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
PADDLE_ENFORCE_EQ(d_table_value->dims(),
d_output_dims_2d,
phi::errors::InvalidArgument(
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s].",
d_table_value->dims(),
d_output_dims_2d));
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
}
private:
const Context& dev_ctx_;
const DenseTensor& input_;
const SelectedRows& weight_;
const DenseTensor& out_grad_;
SelectedRows* weight_grad_;
int64_t padding_idx_;
};
template <typename T, typename Context> template <typename T, typename Context>
void SparseWeightEmbeddingGradKernel(const Context& ctx, void SparseWeightEmbeddingGradKernel(const Context& ctx,
const DenseTensor& input, const DenseTensor& input,
...@@ -108,7 +170,20 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx, ...@@ -108,7 +170,20 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
const DenseTensor& out_grad, const DenseTensor& out_grad,
int64_t padding_idx, int64_t padding_idx,
DenseTensor* weight_grad) { DenseTensor* weight_grad) {
LookupTableV2GradCPUFunctor<T, Context> functor( SparseWeightLookupTableV2GradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
}
template <typename T, typename Context>
void SparseWeightEmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad) {
SparseWeightLookupTableV2SparseGradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad); ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType( paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor); paddle::framework::TransToProtoVarType(input.dtype()), functor);
...@@ -116,10 +191,18 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx, ...@@ -116,10 +191,18 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
} // namespace phi } // namespace phi
PT_REGISTER_KERNEL(sparse_weight_embedding_grad, PD_REGISTER_KERNEL(sparse_weight_embedding_grad,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::SparseWeightEmbeddingGradKernel, phi::SparseWeightEmbeddingGradKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {}
PD_REGISTER_KERNEL(sparse_weight_embedding_sparse_grad,
CPU,
ALL_LAYOUT,
phi::SparseWeightEmbeddingSparseGradKernel,
float,
double,
phi::dtype::float16) {}
...@@ -24,12 +24,12 @@ ...@@ -24,12 +24,12 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
struct LookupTableV2CPUFunctor { struct LookupTableV2CPUSparseFunctor {
LookupTableV2CPUFunctor(const Context& dev_ctx, LookupTableV2CPUSparseFunctor(const Context& dev_ctx,
const DenseTensor& input, const DenseTensor& input,
const SelectedRows& weight, const SelectedRows& weight,
int64_t padding_idx, int64_t padding_idx,
DenseTensor* out) DenseTensor* out)
: dev_ctx_(dev_ctx), : dev_ctx_(dev_ctx),
input_(input), input_(input),
weight_(weight), weight_(weight),
...@@ -94,7 +94,7 @@ void SparseWeightEmbeddingKernel(const Context& ctx, ...@@ -94,7 +94,7 @@ void SparseWeightEmbeddingKernel(const Context& ctx,
const SelectedRows& weight, const SelectedRows& weight,
int64_t padding_idx, int64_t padding_idx,
DenseTensor* out) { DenseTensor* out) {
LookupTableV2CPUFunctor<T, Context> functor( LookupTableV2CPUSparseFunctor<T, Context> functor(
ctx, input, weight, padding_idx, out); ctx, input, weight, padding_idx, out);
paddle::framework::VisitIntDataType( paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor); paddle::framework::TransToProtoVarType(input.dtype()), functor);
...@@ -102,7 +102,7 @@ void SparseWeightEmbeddingKernel(const Context& ctx, ...@@ -102,7 +102,7 @@ void SparseWeightEmbeddingKernel(const Context& ctx,
} // namespace phi } // namespace phi
PT_REGISTER_KERNEL(sparse_weight_embedding, PD_REGISTER_KERNEL(sparse_weight_embedding,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::SparseWeightEmbeddingKernel, phi::SparseWeightEmbeddingKernel,
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi { namespace phi {
...@@ -26,4 +27,12 @@ void EmbeddingGradKernel(const Context& ctx, ...@@ -26,4 +27,12 @@ void EmbeddingGradKernel(const Context& ctx,
int64_t padding_idx, int64_t padding_idx,
DenseTensor* weight_grad); DenseTensor* weight_grad);
template <typename T, typename Context>
void EmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad);
} // namespace phi } // namespace phi
...@@ -21,7 +21,9 @@ ...@@ -21,7 +21,9 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace phi { namespace phi {
template <typename InT, typename OutT> template <typename InT, typename OutT>
...@@ -120,12 +122,117 @@ void EmbeddingGradKernel(const Context& ctx, ...@@ -120,12 +122,117 @@ void EmbeddingGradKernel(const Context& ctx,
paddle::framework::VisitIntDataType( paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor); paddle::framework::TransToProtoVarType(input.dtype()), functor);
} }
template <typename T, typename Context>
struct LookupTableV2SparseGradCUDAFunctor {
LookupTableV2SparseGradCUDAFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
padding_idx_(padding_idx),
weight_grad_(weight_grad) {}
template <typename IdT>
void apply() {
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
const auto* ids_data = input_.template data<IdT>();
auto* d_table = weight_grad_;
auto* table = &weight_;
auto* d_output = &out_grad_;
int64_t ids_num = input_.numel();
dim3 threads(128, 8);
dim3 grids(8, 1);
auto stream = dev_ctx_.stream();
paddle::framework::Vector<int64_t> new_rows;
new_rows.resize(ids_num);
auto gpu_place = dev_ctx_.GetPlace();
paddle::framework::MixVector<int64_t> mixv_new_rows(&new_rows);
if (!std::is_same<IdT, int64_t>::value) {
InputTypeConvert<<<grids, threads, 0, stream>>>(
ids_data, ids_num, mixv_new_rows.MutableData(gpu_place));
} else {
paddle::memory::Copy(gpu_place,
mixv_new_rows.CUDAMutableData(gpu_place),
gpu_place,
ids_data,
ids_num * sizeof(int64_t),
stream);
}
mixv_new_rows.CopyToCPU();
d_table->set_rows(new_rows);
auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table->dims()[1]});
d_table_value->template mutable_data<T>(gpu_place);
auto* d_table_data = d_table_value->template data<T>();
auto* d_output_data = d_output->template data<T>();
auto d_output_dims = d_output->dims();
auto d_output_dims_2d =
phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
PADDLE_ENFORCE_EQ(d_table_value->dims(),
d_output_dims_2d,
phi::errors::InvalidArgument(
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s].",
d_table_value->dims(),
d_output_dims_2d));
paddle::memory::Copy(gpu_place,
d_table_data,
gpu_place,
d_output_data,
d_output->numel() * sizeof(T),
stream);
}
private:
const phi::GPUContext& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
const DenseTensor& out_grad_;
int64_t padding_idx_;
SelectedRows* weight_grad_;
};
template <typename T, typename Context>
void EmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad) {
LookupTableV2SparseGradCUDAFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
}
} // namespace phi } // namespace phi
PT_REGISTER_KERNEL(embedding_grad, PD_REGISTER_KERNEL(embedding_grad,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::EmbeddingGradKernel, phi::EmbeddingGradKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {}
PD_REGISTER_KERNEL(embedding_sparse_grad,
GPU,
ALL_LAYOUT,
phi::EmbeddingSparseGradKernel,
float,
double,
phi::dtype::float16) {}
...@@ -115,7 +115,7 @@ void EmbeddingKernel(const Context &ctx, ...@@ -115,7 +115,7 @@ void EmbeddingKernel(const Context &ctx,
} // namespace phi } // namespace phi
PT_REGISTER_KERNEL(embedding, PD_REGISTER_KERNEL(embedding,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::EmbeddingKernel, phi::EmbeddingKernel,
......
...@@ -27,4 +27,12 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx, ...@@ -27,4 +27,12 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
int64_t padding_idx, int64_t padding_idx,
DenseTensor* weight_grad); DenseTensor* weight_grad);
template <typename T, typename Context>
void SparseWeightEmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad);
} // namespace phi } // namespace phi
...@@ -18,10 +18,8 @@ namespace phi { ...@@ -18,10 +18,8 @@ namespace phi {
KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("W")) { if (ctx.IsDenseTensorInput("W")) {
LOG(ERROR) << "is dense here";
return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"});
} else { } else {
LOG(ERROR) << "is selcted rows";
return KernelSignature( return KernelSignature(
"sparse_weight_embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); "sparse_weight_embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"});
} }
...@@ -30,23 +28,37 @@ KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -30,23 +28,37 @@ KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature EmbeddingGradOpArgumentMapping( KernelSignature EmbeddingGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("W")) { if (ctx.IsDenseTensorInput("W")) {
return KernelSignature("embedding_grad", if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) {
{"Ids", "W", GradVarName("Out")}, return KernelSignature("embedding_sparse_grad",
{"padding_idx"}, {"Ids", "W", GradVarName("Out")},
{GradVarName("W")}); {"padding_idx"},
{GradVarName("W")});
} else {
return KernelSignature("embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"padding_idx"},
{GradVarName("W")});
}
} else { } else {
return KernelSignature("sparse_weight_embedding_grad", if ((paddle::any_cast<bool>(ctx.Attr("is_sparse"))) == true) {
{"Ids", "W", GradVarName("Out")}, return KernelSignature("sparse_weight_embedding_sparse_grad",
{"padding_idx"}, {"Ids", "W", GradVarName("Out")},
{GradVarName("W")}); {"padding_idx"},
{GradVarName("W")});
} else {
return KernelSignature("sparse_weight_embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"padding_idx"},
{GradVarName("W")});
}
} }
} }
} // namespace phi } // namespace phi
PT_REGISTER_BASE_KERNEL_NAME(lookup_table_v2, embedding); PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2, embedding);
PT_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad, embedding_grad); PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad, embedding_grad);
PT_REGISTER_ARG_MAPPING_FN(lookup_table_v2, phi::EmbeddingOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(lookup_table_v2, phi::EmbeddingOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(lookup_table_v2_grad, PD_REGISTER_ARG_MAPPING_FN(lookup_table_v2_grad,
phi::EmbeddingGradOpArgumentMapping); phi::EmbeddingGradOpArgumentMapping);
...@@ -25,23 +25,24 @@ import paddle.compat as cpt ...@@ -25,23 +25,24 @@ import paddle.compat as cpt
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
# class TestStaticGraphSupportMultipleInt(unittest.TestCase):
# def test_main(self): class TestStaticGraphSupportMultipleInt(unittest.TestCase):
# dtypes = ['uint8', 'int8', 'int16', 'int32', 'int64'] def test_main(self):
# if paddle.in_dynamic_mode(): dtypes = ['uint8', 'int8', 'int16', 'int32', 'int64']
# paddle.enable_static() if paddle.in_dynamic_mode():
# disable_static = True paddle.enable_static()
# else: disable_static = True
# disable_static = False else:
# for i, dtype in enumerate(dtypes): disable_static = False
# with paddle.static.program_guard(paddle.static.Program(), for i, dtype in enumerate(dtypes):
# paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program(),
# x = paddle.static.data(name='x', shape=[-1, 7, 30], dtype=dtype) paddle.static.Program()):
# emb = paddle.nn.Embedding(10, 20) x = paddle.static.data(name='x', shape=[-1, 7, 30], dtype=dtype)
# y = emb(x) emb = paddle.nn.Embedding(10, 20)
y = emb(x)
# if disable_static:
# paddle.disable_static() if disable_static:
paddle.disable_static()
class TestLookupTableOp(OpTest): class TestLookupTableOp(OpTest):
...@@ -62,17 +63,19 @@ class TestLookupTableOp(OpTest): ...@@ -62,17 +63,19 @@ class TestLookupTableOp(OpTest):
self.check_grad(['W'], 'Out', no_grad_set=set('Ids')) self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
# class TestLookupTableOpInt16(OpTest): class TestLookupTableOpInt16(OpTest):
# def id_dtype(self): def id_dtype(self):
# return "int16" return "int16"
# class TestLookupTableOpInt8(OpTest):
# def id_dtype(self):
# return "int8"
# class TestLookupTableOpUInt8(OpTest): class TestLookupTableOpInt8(OpTest):
# def id_dtype(self): def id_dtype(self):
# return "uint8" return "int8"
class TestLookupTableOpUInt8(OpTest):
def id_dtype(self):
return "uint8"
class TestLookupTableOpWithTensorIds(OpTest): class TestLookupTableOpWithTensorIds(OpTest):
...@@ -90,183 +93,190 @@ class TestLookupTableOpWithTensorIds(OpTest): ...@@ -90,183 +93,190 @@ class TestLookupTableOpWithTensorIds(OpTest):
self.check_grad(['W'], 'Out', no_grad_set=set('Ids')) self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
# @skip_check_grad_ci( @skip_check_grad_ci(
# reason="Since paddings are not trainable and fixed in forward," reason="Since paddings are not trainable and fixed in forward,"
# "the gradient of paddings makes no sense and we don't " "the gradient of paddings makes no sense and we don't "
# "test the gradient here.") "test the gradient here.")
# class TestLookupTableOpWithPadding(TestLookupTableOp): class TestLookupTableOpWithPadding(TestLookupTableOp):
# def test_check_output(self): def test_check_output(self):
# ids = np.squeeze(self.inputs['Ids']) ids = np.squeeze(self.inputs['Ids'])
# padding_idx = np.random.choice(ids, 1)[0] padding_idx = np.random.choice(ids, 1)[0]
# self.outputs['Out'][ids == padding_idx] = np.zeros(31) self.outputs['Out'][ids == padding_idx] = np.zeros(31)
# self.attrs = {'padding_idx': int(padding_idx)} self.attrs = {'padding_idx': int(padding_idx)}
# self.check_output() self.check_output()
# @skip_check_grad_ci(
# reason="Since paddings are not trainable and fixed in forward," @skip_check_grad_ci(
# "the gradient of paddings makes no sense and we don't " reason="Since paddings are not trainable and fixed in forward,"
# "test the gradient here.") "the gradient of paddings makes no sense and we don't "
# class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds): "test the gradient here.")
# def test_check_output(self): class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
# ids = self.inputs['Ids'] def test_check_output(self):
# flatten_idx = ids.flatten() ids = self.inputs['Ids']
# padding_idx = np.random.choice(flatten_idx, 1)[0] flatten_idx = ids.flatten()
# self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31) padding_idx = np.random.choice(flatten_idx, 1)[0]
# self.attrs = {'padding_idx': cpt.long_type(padding_idx)} self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
# self.check_output() self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
self.check_output()
# class TestLookupTableWIsSelectedRows(unittest.TestCase):
# def prepare_ids(self, scope, place):
# ids_tensor = scope.var('Ids').get_tensor() class TestLookupTableWIsSelectedRows(unittest.TestCase):
# ids_array = np.array([0, 4, 3, 5]).astype("int32") def prepare_ids(self, scope, place):
# ids_tensor.set(ids_array, place) ids_tensor = scope.var('Ids').get_tensor()
# return ids_array ids_array = np.array([0, 4, 3, 5]).astype("int32")
ids_tensor.set(ids_array, place)
# def prepare_w(self, scope, place): return ids_array
# rows = [0, 1, 2, 3, 4, 5, 6]
# row_numel = 12 def prepare_w(self, scope, place):
rows = [0, 1, 2, 3, 4, 5, 6]
# w_selected_rows = scope.var('W').get_selected_rows() row_numel = 12
# w_selected_rows.set_height(len(rows))
# w_selected_rows.set_rows(rows) w_selected_rows = scope.var('W').get_selected_rows()
# w_array = np.ones((len(rows), row_numel)).astype("float32") w_selected_rows.set_height(len(rows))
# for i in range(len(rows)): w_selected_rows.set_rows(rows)
# w_array[i] *= i w_array = np.ones((len(rows), row_numel)).astype("float32")
# w_tensor = w_selected_rows.get_tensor() for i in range(len(rows)):
# w_tensor.set(w_array, place) w_array[i] *= i
w_tensor = w_selected_rows.get_tensor()
# def create_out_tensor(self, scope, place): w_tensor.set(w_array, place)
# return scope.var('Out').get_tensor()
def create_out_tensor(self, scope, place):
# def check_result(self, ids_array, result_array): return scope.var('Out').get_tensor()
# # all(): return True if all elements of the iterable are true (or if the iterable is empty)
# for idx, row in enumerate(ids_array): def check_result(self, ids_array, result_array):
# assert (row == result_array[idx]).all() # all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(ids_array):
# def check_with_place(self, place): assert (row == result_array[idx]).all()
# scope = core.Scope()
def check_with_place(self, place):
# ids_array = self.prepare_ids(scope, place) scope = core.Scope()
# self.prepare_w(scope, place) ids_array = self.prepare_ids(scope, place)
# out_tensor = self.create_out_tensor(scope, place) self.prepare_w(scope, place)
# # create and run lookup_table operator out_tensor = self.create_out_tensor(scope, place)
# lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out')
# lookup_table.run(scope, place) # create and run lookup_table operator
lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out')
# # get result from Out lookup_table.run(scope, place)
# result_array = np.array(out_tensor)
# get result from Out
# self.check_result(ids_array, result_array) result_array = np.array(out_tensor)
# def test_w_is_selected_rows(self): self.check_result(ids_array, result_array)
# places = [core.CPUPlace()]
# # currently only support CPU def test_w_is_selected_rows(self):
# for place in places: places = [core.CPUPlace()]
# self.check_with_place(place) # currently only support CPU
for place in places:
# class TestLookupTableWithTensorIdsWIsSelectedRows( self.check_with_place(place)
# TestLookupTableWIsSelectedRows):
# def prepare_ids(self, scope, place):
# ids_tensor = scope.var('Ids').get_tensor() class TestLookupTableWithTensorIdsWIsSelectedRows(
# ids_array = np.random.randint( TestLookupTableWIsSelectedRows):
# low=0, high=6, size=(2, 4, 3)).astype("int64") def prepare_ids(self, scope, place):
# ids_tensor.set(ids_array, place) ids_tensor = scope.var('Ids').get_tensor()
# return ids_array ids_array = np.random.randint(
low=0, high=6, size=(2, 4, 3)).astype("int64")
# def check_result(self, ids_array, result_array): ids_tensor.set(ids_array, place)
# for idx, row in np.ndenumerate(ids_array): return ids_array
# assert (row == result_array[idx]).all()
def check_result(self, ids_array, result_array):
# class TestLookupTableIsSparse(unittest.TestCase): for idx, row in np.ndenumerate(ids_array):
# def init_data(self): assert (row == result_array[idx]).all()
# self.x_data = np.array([[1, 3, 0, 4, 7]]).astype("int64")
# self.y_data = np.array([[0.1, 0.3, 0, 0.4, 0.7]]).astype("float32")
class TestLookupTableIsSparse(unittest.TestCase):
# def get_w_grad(self, is_sparse): def init_data(self):
# self.init_data() self.x_data = np.array([[1, 3, 0, 4, 7]]).astype("int64")
# main_program = fluid.Program() self.y_data = np.array([[0.1, 0.3, 0, 0.4, 0.7]]).astype("float32")
# with fluid.program_guard(main_program, fluid.Program()):
# x = fluid.layers.data(name='x', shape=[5], dtype='int64') def get_w_grad(self, is_sparse):
# y_ = fluid.layers.data(name='y_', shape=[5], dtype='float32') self.init_data()
# emb = fluid.input.embedding( main_program = fluid.Program()
# input=x, with fluid.program_guard(main_program, fluid.Program()):
# size=[10, 16], x = fluid.layers.data(name='x', shape=[5], dtype='int64')
# param_attr=fluid.ParamAttr( y_ = fluid.layers.data(name='y_', shape=[5], dtype='float32')
# name="emb_weight", emb = fluid.input.embedding(
# learning_rate=10, input=x,
# initializer=fluid.initializer.NumpyArrayInitializer( size=[10, 16],
# self.w_data)), param_attr=fluid.ParamAttr(
# is_sparse=is_sparse) name="emb_weight",
# y = fluid.layers.reduce_sum(emb, dim=-1) learning_rate=10,
initializer=fluid.initializer.NumpyArrayInitializer(
# loss = fluid.layers.square_error_cost(input=y, label=y_) self.w_data)),
# loss = fluid.layers.mean(loss) is_sparse=is_sparse)
y = fluid.layers.reduce_sum(emb, dim=-1)
# sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-4)
# sgd_optimizer.minimize(loss) loss = fluid.layers.square_error_cost(input=y, label=y_)
loss = fluid.layers.mean(loss)
# place = fluid.CPUPlace()
# exe = fluid.Executor(place) sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-4)
# exe.run(fluid.default_startup_program()) sgd_optimizer.minimize(loss)
# ret = exe.run(feed={'x': self.x_data,
# 'y_': self.y_data}, place = fluid.CPUPlace()
# fetch_list=['emb_weight'], exe = fluid.Executor(place)
# return_numpy=False) exe.run(fluid.default_startup_program())
# return np.array(ret[0]) ret = exe.run(feed={'x': self.x_data,
'y_': self.y_data},
# def test_w_grad(self): fetch_list=['emb_weight'],
# self.w_data = np.random.random(size=(10, 16)).astype("float32") return_numpy=False)
# w_grad = self.get_w_grad(False) return np.array(ret[0])
# w_grad_with_sparse = self.get_w_grad(True)
# self.check_grad(w_grad, w_grad_with_sparse) def test_w_grad(self):
self.w_data = np.random.random(size=(10, 16)).astype("float32")
# def check_grad(self, w_grad1, w_grad2, tolerance=1e-6): w_grad = self.get_w_grad(False)
# np.testing.assert_allclose( w_grad_with_sparse = self.get_w_grad(True)
# w_grad1, w_grad2, rtol=tolerance, atol=tolerance) self.check_grad(w_grad, w_grad_with_sparse)
# class TestLookupTableApi(unittest.TestCase): def check_grad(self, w_grad1, w_grad2, tolerance=1e-6):
# def test_api(self): np.testing.assert_allclose(
# x = fluid.layers.data(name='x', shape=[20], dtype='int64') w_grad1, w_grad2, rtol=tolerance, atol=tolerance)
# emb = fluid.embedding(input=x, size=[128, 64])
# place = fluid.CPUPlace() class TestLookupTableApi(unittest.TestCase):
# x_data = np.random.randint(0, 127, [2, 20]).astype("int64") def test_api(self):
x = fluid.layers.data(name='x', shape=[20], dtype='int64')
# exe = fluid.Executor(place) emb = fluid.embedding(input=x, size=[128, 64])
# exe.run(fluid.default_startup_program())
# ret = exe.run(feed={'x': x_data, }, place = fluid.CPUPlace()
# fetch_list=[emb], x_data = np.random.randint(0, 127, [2, 20]).astype("int64")
# return_numpy=False)
exe = fluid.Executor(place)
# class TestEmbedOpError(unittest.TestCase): exe.run(fluid.default_startup_program())
# def test_errors(self): ret = exe.run(feed={'x': x_data, },
# with program_guard(Program(), Program()): fetch_list=[emb],
# input_data = np.random.randint(0, 10, (4, 6)).astype("int64") return_numpy=False)
# def test_Variable():
# # the input type must be Variable class TestEmbedOpError(unittest.TestCase):
# fluid.embedding(input=input_data, size=(10, 64)) def test_errors(self):
with program_guard(Program(), Program()):
# self.assertRaises(TypeError, test_Variable) input_data = np.random.randint(0, 10, (4, 6)).astype("int64")
# def test_input_dtype(): def test_Variable():
# # the input dtype must be int64 # the input type must be Variable
# input = fluid.data(name='x1', shape=[4, 6], dtype='float32') fluid.embedding(input=input_data, size=(10, 64))
# fluid.embedding(input=input, size=(10, 64))
self.assertRaises(TypeError, test_Variable)
# self.assertRaises(TypeError, test_input_dtype)
def test_input_dtype():
# def test_param_dtype(): # the input dtype must be int64
# # dtype must be float32 or float64 input = fluid.data(name='x1', shape=[4, 6], dtype='float32')
# input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64') fluid.embedding(input=input, size=(10, 64))
# fluid.embedding(input=input2, size=(10, 64), dtype='int64')
self.assertRaises(TypeError, test_input_dtype)
# self.assertRaises(TypeError, test_param_dtype)
# input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64') def test_param_dtype():
# fluid.embedding(input=input3, size=(10, 64), dtype='float16') # dtype must be float32 or float64
input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64')
fluid.embedding(input=input2, size=(10, 64), dtype='int64')
self.assertRaises(TypeError, test_param_dtype)
input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64')
fluid.embedding(input=input3, size=(10, 64), dtype='float16')
if __name__ == "__main__": if __name__ == "__main__":
paddle.enable_static() paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册