Unverified commit 434caab2, authored by colourful-tree, committed by GitHub

Merge pull request #16741 from colourful-tree/dev

add continuous value model op
......@@ -241,6 +241,7 @@ paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output
paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329'))
paddle.fluid.layers.pixel_shuffle (ArgSpec(args=['x', 'upscale_factor'], varargs=None, keywords=None, defaults=None), ('document', '731b21c62a4add60a33bd76d802ffc5c'))
paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393'))
paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', 'a07a44c2bacdcd09c1f5f35a96a0514e'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '33bbd42027d872b3818b3d64ec52e139'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'b1ae2e1cc0750e58726374061ea90ecc'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', 'b0a1c2fc51c27a106da28f3308c41f5e'))
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cvm_op.h"
#include <memory>
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class CVMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("CVM"), "Input(CVM) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto cvm_dims = ctx->GetInputDim("CVM");
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(cvm_dims.size(), 2UL, "Input(CVM)'s rank should be 2.");
PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL,
"The 2nd dimension of "
"Input(CVM) should be 2.");
if (ctx->Attrs().Get<bool>("use_cvm")) {
ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]});
} else {
ctx->SetOutputDim("Y", {x_dims[0], x_dims[1] - 2});
}
ctx->ShareLoD("X", /*->*/ "Y");
}
protected:
  // Explicitly set that the data type of the cvm computation kernel is
  // determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
platform::CPUPlace());
}
};
class CVMGradientOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("CVM"), "Input(CVM) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto cvm_dims = ctx->GetInputDim("CVM");
auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
PADDLE_ENFORCE_EQ(cvm_dims.size(), 2, "Input(CVM)'s rank should be 2.");
PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
"The 1st dimension of Input(X) and Input(Y@Grad) should "
"be equal.");
    PADDLE_ENFORCE_EQ(cvm_dims[1], 2,
                      "The 2nd dimension of Input(CVM) should be 2.");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD("X", framework::GradVarName("X"));
}
protected:
  // Explicitly set that the data type of the cvm computation kernel is
  // determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
platform::CPUPlace());
}
};
class CVMOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(LodTensor, default LodTensor<float>), a 2-D tensor with shape "
"[N x D],"
" where N is the batch size and D is the emebdding dim. ");
AddInput("CVM",
"(Tensor), a 2-D Tensor with shape [N x 2], where N is the batch "
"size, 2 is show and click.");
AddOutput("Y",
"(LodTensor, default LodTensor<float>), a 2-D tensor with shape "
"[N x K].");
AddAttr<bool>("use_cvm", "bool, use cvm or not").SetDefault(true);
AddComment(R"DOC(
CVM Operator.
We assume that the input X is an embedding vector with cvm features (show and click), whose shape is [N x D] (D is 2 (cvm features) + embedding dim, N is batch size).
If use_cvm is True, we apply a log transform to the cvm features, and the output shape is [N x D].
If use_cvm is False, we remove the cvm features from the input, and the output shape is [N x (D - 2)].
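For example, with N = 1 and D = 4, X = [show, click, e0, e1] = [10, 3, 0.5, 0.2]:
use_cvm = True  -> Y = [log(10 + 1), log(3 + 1) - log(10 + 1), 0.5, 0.2];
use_cvm = False -> Y = [0.5, 0.2].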
)DOC");
}
};
class CVMGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("cvm_grad");
op->SetInput("X", Input("X"));
op->SetInput("CVM", Input("CVM"));
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(cvm, ops::CVMOp, ops::CVMOpMaker, ops::CVMGradOpDescMaker);
REGISTER_OPERATOR(cvm_grad, ops::CVMGradientOp);
REGISTER_OP_CPU_KERNEL(cvm, ops::CVMOpKernel<float>, ops::CVMOpKernel<double>);
REGISTER_OP_CPU_KERNEL(cvm_grad, ops::CVMGradOpKernel<float>,
ops::CVMGradOpKernel<double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T>
class CVMOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const LoDTensor* x = context.Input<LoDTensor>("X");
const T* x_data = x->data<T>();
auto lod = x->lod()[0];
int64_t item_size = x->numel() / x->dims()[0];
int offset = 2;
if (!context.Attr<bool>("use_cvm")) {
item_size -= offset;
}
LoDTensor* y = context.Output<LoDTensor>("Y");
T* y_data = y->mutable_data<T>(context.GetPlace());
int seq_num = static_cast<int>(lod.size()) - 1;
for (int i = 0; i < seq_num; ++i) {
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
for (int j = 0; j < seq_len; ++j) {
if (context.Attr<bool>("use_cvm")) {
std::memcpy(y_data, x_data, item_size * sizeof(T));
y_data[0] = log(y_data[0] + 1);
y_data[1] = log(y_data[1] + 1) - y_data[0];
x_data += item_size;
y_data += item_size;
} else {
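          // use_cvm == false: skip the two leading CVM slots in the input and
          // copy only the embedding part (item_size was already reduced by
          // offset).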
std::memcpy(y_data, x_data + offset, item_size * sizeof(T));
x_data += item_size + offset;
y_data += item_size;
}
}
}
}
};
template <typename T>
class CVMGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
LoDTensor* dx = context.Output<LoDTensor>(framework::GradVarName("X"));
T* dx_data = dx->mutable_data<T>(context.GetPlace());
const Tensor* cvm = context.Input<Tensor>("CVM");
const T* cvm_data = cvm->data<T>();
int offset = 2;
const framework::LoDTensor* dOut =
context.Input<framework::LoDTensor>(framework::GradVarName("Y"));
const T* dout_data = dOut->data<T>();
auto lod = dx->lod()[0];
int64_t item_size = dx->numel() / dx->dims()[0];
if (!context.Attr<bool>("use_cvm")) {
item_size -= offset;
}
int seq_num = static_cast<int>(lod.size()) - 1;
for (int i = 0; i < seq_num; ++i) {
int64_t seq_len = static_cast<int64_t>(lod[i + 1] - lod[i]);
for (int j = 0; j < seq_len; ++j) {
if (context.Attr<bool>("use_cvm")) {
std::memcpy(dx_data, dout_data, item_size * sizeof(T));
dx_data[0] = cvm_data[0];
dx_data[1] = cvm_data[1];
dx_data += item_size;
dout_data += item_size;
} else {
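          // use_cvm == false: the forward pass dropped the CVM slots, so
          // write the gradient back at an offset of 2 and restore the CVM
          // slots from the CVM input.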
std::memcpy(dx_data + offset, dout_data, item_size * sizeof(T));
dx_data[0] = cvm_data[0];
dx_data[1] = cvm_data[1];
dx_data += item_size + offset;
dout_data += item_size;
}
}
cvm_data += offset;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -196,6 +196,7 @@ __all__ = [
'npair_loss',
'pixel_shuffle',
'fsp_matrix',
'continuous_value_model',
]
kIgnoreIndex = -100
......@@ -11202,3 +11203,54 @@ def fsp_matrix(x, y):
input_param_name='x'))
helper.append_op(type='fsp', inputs={'X': x, 'Y': y}, outputs={'Out': out})
return out
def continuous_value_model(input, cvm, use_cvm=True):
"""
**continuous_value_model layers**
    continuous value model (cvm). Currently, it only considers the show and click values in CTR projects.
    We assume the input is an embedding vector with cvm features, whose shape is [N x D] (D is 2 + embedding dim).
    If use_cvm is True, it applies a log transform to the cvm features, and the output shape is [N x D].
    If use_cvm is False, it removes the cvm features from the input, and the output shape is [N x (D - 2)].
    This layer accepts a tensor named input that holds the embedded IDs (lod level is 1); cvm holds the show/click info.
    Args:
        input (Variable): a 2-D LoDTensor with shape [N x D], where N is the batch size and D is 2 + the embedding dim. lod level = 1.
        cvm (Variable): a 2-D Tensor with shape [N x 2], where N is the batch size and the 2 columns are show and click.
        use_cvm (bool): whether to use the cvm features. If True, the output dim is the same as the input dim;
            if False, the output dim is the input dim - 2 (show and click are removed).
            (cvm is a customized op whose input sequence carries the embedding together with the cvm features by default, so the op needs this flag to decide whether to keep them.)
    Returns:
        Variable: a 2-D LoDTensor with shape [N x D] if use_cvm is True, or [N x (D - 2)] if use_cvm is False.
Examples:
.. code-block:: python
            import paddle.fluid as fluid
            input = fluid.layers.data(name="input", shape=[-1, 1], lod_level=1, append_batch_size=False, dtype="int64")
label = fluid.layers.data(name="label", shape=[-1, 1], append_batch_size=False, dtype="int64")
embed = fluid.layers.embedding(
input=input,
size=[100, 11],
dtype='float32')
ones = fluid.layers.fill_constant_batch_size_like(input=label, shape=[-1, 1], dtype="int64", value=1)
show_clk = fluid.layers.cast(fluid.layers.concat([ones, label], axis=1), dtype='float32')
show_clk.stop_gradient = True
input_with_cvm = fluid.layers.continuous_value_model(embed, show_clk, True)
"""
helper = LayerHelper('cvm', **locals())
out = helper.create_variable(dtype=input.dtype)
helper.append_op(
type='cvm',
inputs={'X': [input],
'CVM': [cvm]},
outputs={'Y': [out]},
attrs={"use_cvm": use_cvm})
return out
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from math import log
from op_test import OpTest
import unittest
class TestCVMOp(OpTest):
"""
    Test the cvm op with use_cvm=False: the two cvm feature columns are stripped from the input.
"""
def setUp(self):
self.op_type = "cvm"
        dims = 11
lod = [[1]]
self.inputs = {
'X': (np.random.uniform(0, 1, [1, dims]).astype("float32"), lod),
'CVM': np.array([[0.6, 0.4]]).astype("float32"),
}
self.attrs = {'use_cvm': False}
out = []
for index, emb in enumerate(self.inputs["X"][0]):
out.append(emb[2:])
self.outputs = {'Y': (np.array(out), lod)}
def test_check_output(self):
self.check_output()
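# A minimal companion sketch (not part of the original commit): it exercises
# the use_cvm=True path, where the kernel keeps the embedding and rewrites the
# two leading slots as log(show + 1) and log(click + 1) - log(show + 1).
class TestCVMOpWithUseCVM(OpTest):
    def setUp(self):
        self.op_type = "cvm"
        dims = 11
        lod = [[1]]
        x = np.random.uniform(0, 1, [1, dims]).astype("float32")
        self.inputs = {
            'X': (x, lod),
            'CVM': np.array([[0.6, 0.4]]).astype("float32"),
        }
        self.attrs = {'use_cvm': True}
        out = x.copy()
        # Mirror the CPU kernel: y[0] = log(x[0] + 1),
        # then y[1] = log(x[1] + 1) - y[0].
        out[0][0] = log(out[0][0] + 1)
        out[0][1] = log(out[0][1] + 1) - out[0][0]
        self.outputs = {'Y': (out, lod)}

    def test_check_output(self):
        self.check_output()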
if __name__ == '__main__':
unittest.main()