提交 1c58eee9 编写于 作者: L luotao1 提交者: ceci3

refine infershape of sequence_enumerate, hash and fuse_emb_seq_pool

test=develop
上级 3f4aeed5
......@@ -23,6 +23,9 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->IsRuntime()) {
return;
}
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input W of FusedEmbeddingSeqPoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"),
......@@ -42,36 +45,15 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
// we only support sum now
PADDLE_ENFORCE_EQ(combiner, "sum");
int64_t last_dim = table_dims[1];
for (int i = 1; i != ids_dims.size(); ++i) {
last_dim *= ids_dims[i];
}
if (ctx->IsRuntime()) {
framework::Variable* ids_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Ids")[0]);
const auto& ids_lod = ids_var->Get<LoDTensor>().lod();
int64_t last_dim = FusedEmbeddingSeqPoolLastDim(table_dims, ids_dims);
// in compile time, the lod level of ids must be 1
framework::VarDesc* ids_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);
// in run time, the LoD of ids must be 1
PADDLE_ENFORCE(ids_lod.size(), 1u,
"The LoD level of Input(Ids) must be 1");
PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
int64_t batch_size = ids_lod[0].size() - 1;
// in run time, the shape from Ids -> output
// should be [seq_length, 1] -> [batch_size, embedding_size]
ctx->SetOutputDim("Out", framework::make_ddim({batch_size, last_dim}));
} else {
// in compile time, the lod level of ids must be 1
framework::VarDesc* ids_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);
// in compile time, the shape from Ids -> output
// should be [-1, 1] -> [-1, embedding_size]
ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim}));
}
// in compile time, the shape from Ids -> output
// should be [-1, 1] -> [-1, embedding_size]
ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim}));
}
protected:
......
......@@ -61,6 +61,15 @@ struct EmbeddingVSumFunctor {
}
};
inline int FusedEmbeddingSeqPoolLastDim(const framework::DDim &table_dims,
const framework::DDim &ids_dims) {
int64_t last_dim = table_dims[1];
for (int i = 1; i != ids_dims.size(); ++i) {
last_dim *= ids_dims[i];
}
return last_dim;
}
template <typename T>
class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
public:
......@@ -70,6 +79,17 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
const LoDTensor *table_var = context.Input<LoDTensor>("W");
const std::string &combiner_type = context.Attr<std::string>("combiner");
int64_t last_dim =
FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims());
const auto &ids_lod = ids_t->lod();
// in run time, the LoD of ids must be 1
PADDLE_ENFORCE(ids_lod.size(), 1u, "The LoD level of Input(Ids) must be 1");
PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
int64_t batch_size = ids_lod[0].size() - 1;
// in run time, the shape from Ids -> output
// should be [seq_length, 1] -> [batch_size, embedding_size]
output_t->Resize({batch_size, last_dim});
if (combiner_type == "sum") {
EmbeddingVSumFunctor<T> functor;
functor(context, table_var, ids_t, output_t);
......
......@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/hash_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
......@@ -27,6 +26,9 @@ class HashOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
if (ctx->IsRuntime()) {
return;
}
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of HashOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -36,15 +38,8 @@ class HashOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(dims.size(), 2UL,
"The input of hash_op's dimensions must be 2");
std::vector<int64_t> out_dims;
out_dims.reserve(dims.size() + 1);
// copy all dims except the last one
for (int i = 0u; i != dims.size() - 1; ++i) {
out_dims.emplace_back(dims[i]);
}
int num_hash = ctx->Attrs().Get<int>("num_hash");
out_dims.emplace_back(num_hash);
// keep the last dim to 1
out_dims.emplace_back(1);
HashOutputSize(dims, out_dims, num_hash);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
ctx->ShareLoD("X", /*->*/ "Out");
......@@ -71,4 +66,4 @@ $$Out = scale * X$$
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(hash, ops::HashOp, ops::HashOpMaker);
REGISTER_OP_CPU_KERNEL(hash, ops::HashKerel<int>, ops::HashKerel<int64_t>);
REGISTER_OP_CPU_KERNEL(hash, ops::HashKernel<int>, ops::HashKernel<int64_t>);
......@@ -17,21 +17,34 @@ limitations under the License. */
extern "C" {
#include <xxhash.h>
}
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
// template <typename DeviceContext, typename T>
inline void HashOutputSize(const framework::DDim& in_dims,
std::vector<int64_t>& out_dims, // NOLINT
int num_hash) {
out_dims.reserve(in_dims.size() + 1);
// copy all dims except the last one
for (int i = 0u; i != in_dims.size() - 1; ++i) {
out_dims.emplace_back(in_dims[i]);
}
out_dims.emplace_back(num_hash);
// keep the last dim to 1
out_dims.emplace_back(1);
}
template <typename T>
class HashKerel : public framework::OpKernel<T> {
class HashKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& context) const {
auto* out_t = context.Output<framework::LoDTensor>("Out");
auto* in_t = context.Input<framework::LoDTensor>("X");
int mod_by = context.Attr<int>("mod_by");
int num_hash = context.Attr<int>("num_hash");
auto* output = out_t->mutable_data<T>(context.GetPlace());
auto in_dims = in_t->dims();
auto in_lod = in_t->lod();
......@@ -39,6 +52,11 @@ class HashKerel : public framework::OpKernel<T> {
static_cast<uint64_t>(in_dims[0]), in_lod[0].back(),
"The actual input data's size mismatched with LoD information.");
std::vector<int64_t> out_dims;
HashOutputSize(in_dims, out_dims, num_hash);
out_t->Resize(framework::make_ddim(out_dims));
auto* output = out_t->mutable_data<T>(context.GetPlace());
auto seq_length = in_dims[0];
auto last_dim = in_dims[in_dims.size() - 1];
auto* input = in_t->data<T>();
......@@ -49,6 +67,7 @@ class HashKerel : public framework::OpKernel<T> {
}
input += last_dim;
}
out_t->set_lod(in_t->lod());
}
};
......
......@@ -22,6 +22,9 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (ctx->IsRuntime()) {
return;
}
PADDLE_ENFORCE(
ctx->HasInput("X"),
"Input(X) of SequecceEnumerate operator should not be null.");
......@@ -33,9 +36,9 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
x_dims.size(), 2,
"Input(X) of SequenceEnumerate operator's rank should be 2.");
PADDLE_ENFORCE_EQ(
x_dims[1], 1,
"Input(X) of SequenceEnumerate operator's 2nd dimension should be 1.");
PADDLE_ENFORCE_EQ(x_dims[1], 1,
"Input(X) of SequenceEnumerate operator's 2nd "
"dimension should be 1.");
const auto win_size = ctx->Attrs().Get<int>("win_size");
ctx->SetOutputDim("Out", {x_dims[0], win_size});
......
......@@ -65,6 +65,7 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
auto lod0 = in_lod[0];
auto in_len = in->numel();
auto in_data = in->data<T>();
out->Resize({in_dims[0], win_size});
auto out_data = out->mutable_data<T>(context.GetPlace());
// Copy LoD to GPU
const size_t* dev_in_lod_ptr = lod0.CUDAData(context.GetPlace());
......@@ -72,6 +73,7 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
CalcOutPut<<<(in_len - 1) / PADDLE_CUDA_NUM_THREADS + 1,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_data, dev_in_lod_ptr, lod0.size(), win_size, pad_value, out_data);
out->set_lod(in->lod());
}
};
......
......@@ -39,6 +39,7 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
// Generate enumerate sequence set
auto lod0 = in_lod[0];
auto in_data = in->data<T>();
out->Resize({in_dims[0], win_size});
auto out_data = out->mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < lod0.size() - 1; ++i) {
for (size_t idx = lod0[i]; idx < lod0[i + 1]; ++idx) {
......@@ -49,6 +50,7 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
}
}
}
out->set_lod(in->lod());
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册