Commit 1c58eee9 authored by luotao1, committed by ceci3

refine infershape of sequence_enumerate, hash and fuse_emb_seq_pool

test=develop

Parent 3f4aeed5
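Summary: the InferShape of all three ops becomes compile-time only. Each InferShape now returns immediately when invoked at runtime, and the corresponding kernels take over the runtime work: they compute the concrete output shape, Resize the output tensor before requesting its data pointer, and copy the input LoD forward with set_lod. The misspelled HashKerel kernel is also renamed to HashKernel. A minimal sketch of the shared pattern, for a generic operator (illustrative only, not a line of this diff):

    void InferShape(framework::InferShapeContext* ctx) const override {
      if (ctx->IsRuntime()) {
        return;  // runtime shape and LoD are resolved in the kernel's Compute()
      }
      // compile time: validate inputs/attrs and set output dims, using -1
      // for dimensions (such as batch size) that are only known at runtime
    }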
paddle/fluid/operators/fused_embedding_seq_pool_op.cc

@@ -23,6 +23,9 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
+    if (ctx->IsRuntime()) {
+      return;
+    }
     PADDLE_ENFORCE(ctx->HasInput("W"),
                    "Input W of FusedEmbeddingSeqPoolOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Ids"),
@@ -42,36 +45,15 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
     // we only support sum now
     PADDLE_ENFORCE_EQ(combiner, "sum");

-    int64_t last_dim = table_dims[1];
-    for (int i = 1; i != ids_dims.size(); ++i) {
-      last_dim *= ids_dims[i];
-    }
-
-    if (ctx->IsRuntime()) {
-      framework::Variable* ids_var =
-          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Ids")[0]);
-      const auto& ids_lod = ids_var->Get<LoDTensor>().lod();
-
-      // in run time, the LoD of ids must be 1
-      PADDLE_ENFORCE(ids_lod.size(), 1u,
-                     "The LoD level of Input(Ids) must be 1");
-      PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
-
-      int64_t batch_size = ids_lod[0].size() - 1;
-
-      // in run time, the shape from Ids -> output
-      // should be [seq_length, 1] -> [batch_size, embedding_size]
-      ctx->SetOutputDim("Out", framework::make_ddim({batch_size, last_dim}));
-    } else {
-      // in compile time, the lod level of ids must be 1
-      framework::VarDesc* ids_desc =
-          boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
-      PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);
-
-      // in compile time, the shape from Ids -> output
-      // should be [-1, 1] -> [-1, embedding_size]
-      ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim}));
-    }
+    int64_t last_dim = FusedEmbeddingSeqPoolLastDim(table_dims, ids_dims);
+    // in compile time, the lod level of ids must be 1
+    framework::VarDesc* ids_desc =
+        boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
+    PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);
+
+    // in compile time, the shape from Ids -> output
+    // should be [-1, 1] -> [-1, embedding_size]
+    ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim}));
   }

  protected:
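At compile time the batch dimension is unknown, so Out is set to {-1, last_dim}: the -1 is a placeholder for the batch size, which the kernel derives from the LoD at runtime (see the header changes below).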
paddle/fluid/operators/fused_embedding_seq_pool_op.h

@@ -61,6 +61,15 @@ struct EmbeddingVSumFunctor {
   }
 };

+inline int FusedEmbeddingSeqPoolLastDim(const framework::DDim &table_dims,
+                                        const framework::DDim &ids_dims) {
+  int64_t last_dim = table_dims[1];
+  for (int i = 1; i != ids_dims.size(); ++i) {
+    last_dim *= ids_dims[i];
+  }
+  return last_dim;
+}
+
 template <typename T>
 class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
  public:
@@ -70,6 +79,17 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
     const LoDTensor *table_var = context.Input<LoDTensor>("W");
     const std::string &combiner_type = context.Attr<std::string>("combiner");

+    int64_t last_dim =
+        FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims());
+    const auto &ids_lod = ids_t->lod();
+    // in run time, the LoD of ids must be 1
+    PADDLE_ENFORCE(ids_lod.size(), 1u, "The LoD level of Input(Ids) must be 1");
+    PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
+    int64_t batch_size = ids_lod[0].size() - 1;
+    // in run time, the shape from Ids -> output
+    // should be [seq_length, 1] -> [batch_size, embedding_size]
+    output_t->Resize({batch_size, last_dim});
+
     if (combiner_type == "sum") {
       EmbeddingVSumFunctor<T> functor;
       functor(context, table_var, ids_t, output_t);
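To illustrate the helper and the runtime resize, with shapes assumed purely for the example:

    // W (table) of shape [10000, 64], Ids of shape [37, 1]:
    //   last_dim = table_dims[1] * ids_dims[1] = 64 * 1 = 64
    // ids_lod[0] = {0, 15, 37} marks two sequences, so
    //   batch_size = ids_lod[0].size() - 1 = 2
    // and Out is resized to [2, 64]: one summed embedding row per sequence.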
paddle/fluid/operators/hash_op.cc

@@ -14,7 +14,6 @@ limitations under the License. */

 #include "paddle/fluid/operators/hash_op.h"
 #include <string>
-#include <vector>

 namespace paddle {
 namespace operators {
@@ -27,6 +26,9 @@ class HashOp : public framework::OperatorWithKernel {
       : OperatorWithKernel(type, inputs, outputs, attrs) {}

   void InferShape(framework::InferShapeContext *ctx) const override {
+    if (ctx->IsRuntime()) {
+      return;
+    }
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of HashOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -36,15 +38,8 @@ class HashOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(dims.size(), 2UL,
                       "The input of hash_op's dimensions must be 2");
     std::vector<int64_t> out_dims;
-    out_dims.reserve(dims.size() + 1);
-    // copy all dims except the last one
-    for (int i = 0u; i != dims.size() - 1; ++i) {
-      out_dims.emplace_back(dims[i]);
-    }
     int num_hash = ctx->Attrs().Get<int>("num_hash");
-    out_dims.emplace_back(num_hash);
-    // keep the last dim to 1
-    out_dims.emplace_back(1);
+    HashOutputSize(dims, out_dims, num_hash);

     ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
     ctx->ShareLoD("X", /*->*/ "Out");
@@ -71,4 +66,4 @@ $$Out = scale * X$$

 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(hash, ops::HashOp, ops::HashOpMaker);
-REGISTER_OP_CPU_KERNEL(hash, ops::HashKerel<int>, ops::HashKerel<int64_t>);
+REGISTER_OP_CPU_KERNEL(hash, ops::HashKernel<int>, ops::HashKernel<int64_t>);
paddle/fluid/operators/hash_op.h

@@ -17,21 +17,34 @@ limitations under the License. */
 extern "C" {
 #include <xxhash.h>
 }
+#include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"

 namespace paddle {
 namespace operators {
-// template <typename DeviceContext, typename T>
+
+inline void HashOutputSize(const framework::DDim& in_dims,
+                           std::vector<int64_t>& out_dims,  // NOLINT
+                           int num_hash) {
+  out_dims.reserve(in_dims.size() + 1);
+  // copy all dims except the last one
+  for (int i = 0u; i != in_dims.size() - 1; ++i) {
+    out_dims.emplace_back(in_dims[i]);
+  }
+  out_dims.emplace_back(num_hash);
+  // keep the last dim to 1
+  out_dims.emplace_back(1);
+}
+
 template <typename T>
-class HashKerel : public framework::OpKernel<T> {
+class HashKernel : public framework::OpKernel<T> {
  public:
   virtual void Compute(const framework::ExecutionContext& context) const {
     auto* out_t = context.Output<framework::LoDTensor>("Out");
     auto* in_t = context.Input<framework::LoDTensor>("X");
     int mod_by = context.Attr<int>("mod_by");
     int num_hash = context.Attr<int>("num_hash");
-    auto* output = out_t->mutable_data<T>(context.GetPlace());

     auto in_dims = in_t->dims();
     auto in_lod = in_t->lod();
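HashOutputSize keeps every input dimension except the last, appends num_hash, and pins a trailing 1, matching what HashOp::InferShape used to compute inline. A small worked example with assumed values:

    // in_dims = [6, 4], num_hash = 3
    std::vector<int64_t> out_dims;
    HashOutputSize(framework::make_ddim({6, 4}), out_dims, 3);
    // out_dims == {6, 3, 1}: 6 rows, 3 hash values per row, last dim kept at 1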
@@ -39,6 +52,11 @@ class HashKerel : public framework::OpKernel<T> {
                       static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
                       "The actual input data's size mismatched with LoD information.");

+    std::vector<int64_t> out_dims;
+    HashOutputSize(in_dims, out_dims, num_hash);
+    out_t->Resize(framework::make_ddim(out_dims));
+    auto* output = out_t->mutable_data<T>(context.GetPlace());
+
     auto seq_length = in_dims[0];
     auto last_dim = in_dims[in_dims.size() - 1];
     auto* input = in_t->data<T>();
@@ -49,6 +67,7 @@ class HashKerel : public framework::OpKernel<T> {
       }
       input += last_dim;
     }
+    out_t->set_lod(in_t->lod());
   }
 };
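Note the ordering in HashKernel::Compute: mutable_data allocates memory according to the tensor's current dims, so the Resize to the shape produced by HashOutputSize has to happen before the data pointer is requested. That is why the early mutable_data call is deleted and reissued after the Resize.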
paddle/fluid/operators/sequence_enumerate_op.cc

@@ -22,6 +22,9 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
+    if (ctx->IsRuntime()) {
+      return;
+    }
     PADDLE_ENFORCE(
         ctx->HasInput("X"),
         "Input(X) of SequecceEnumerate operator should not be null.");
@@ -33,9 +36,9 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(
         x_dims.size(), 2,
         "Input(X) of SequenceEnumerate operator's rank should be 2.");
-    PADDLE_ENFORCE_EQ(
-        x_dims[1], 1,
-        "Input(X) of SequenceEnumerate operator's 2nd dimension should be 1.");
+    PADDLE_ENFORCE_EQ(x_dims[1], 1,
+                      "Input(X) of SequenceEnumerate operator's 2nd "
+                      "dimension should be 1.");

     const auto win_size = ctx->Attrs().Get<int>("win_size");
     ctx->SetOutputDim("Out", {x_dims[0], win_size});
paddle/fluid/operators/sequence_enumerate_op.cu

@@ -65,6 +65,7 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
     auto lod0 = in_lod[0];
     auto in_len = in->numel();
     auto in_data = in->data<T>();
+    out->Resize({in_dims[0], win_size});
     auto out_data = out->mutable_data<T>(context.GetPlace());
     // Copy LoD to GPU
     const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace());
@@ -72,6 +73,7 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
     CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
                  PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
         in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data);
+    out->set_lod(in->lod());
   }
 };
paddle/fluid/operators/sequence_enumerate_op.h

@@ -39,6 +39,7 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
     // Generate enumerate sequence set
     auto lod0 = in_lod[0];
     auto in_data = in->data<T>();
+    out->Resize({in_dims[0], win_size});
     auto out_data = out->mutable_data<T>(context.GetPlace());
     for (size_t i = 0; i < lod0.size() - 1; ++i) {
       for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) {
@@ -49,6 +50,7 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
         }
       }
     }
+    out->set_lod(in->lod());
   }
 };
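Both sequence_enumerate kernels now mirror the compile-time contract at runtime: the output is resized to [x_dims[0], win_size] and the input LoD is copied forward with set_lod, since the ShareLoD call in InferShape no longer runs at runtime. With assumed values, an input X of shape [6, 1] with lod {{0, 3, 6}} and win_size = 2 yields Out of shape [6, 2] carrying the same lod.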