From 54dddee37e620639610bf410421bac2a22a791a2 Mon Sep 17 00:00:00 2001 From: heqiaozhi Date: Tue, 9 Apr 2019 20:05:03 +0800 Subject: [PATCH] add continuous value model op test=develop --- paddle/fluid/API.spec | 1 + paddle/fluid/operators/cvm_op.cc | 163 +++++++++++++++++++++++++++++++ paddle/fluid/operators/cvm_op.h | 105 ++++++++++++++++++++ python/paddle/fluid/layers/nn.py | 45 +++++++++ 4 files changed, 314 insertions(+) create mode 100644 paddle/fluid/operators/cvm_op.cc create mode 100644 paddle/fluid/operators/cvm_op.h diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index d71d792b4..2f04a1e9d 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -237,6 +237,7 @@ paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329')) paddle.fluid.layers.pixel_shuffle (ArgSpec(args=['x', 'upscale_factor'], varargs=None, keywords=None, defaults=None), ('document', 'ad669cdf83e72a69ebc5ed79e36486de')) paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393')) +paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', 'f870a9e750f2309f044c24bbdc3f232e')) paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '33bbd42027d872b3818b3d64ec52e139')) paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'b1ae2e1cc0750e58726374061ea90ecc')) paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', 'b0a1c2fc51c27a106da28f3308c41f5e')) diff --git a/paddle/fluid/operators/cvm_op.cc b/paddle/fluid/operators/cvm_op.cc new file mode 100644 index 000000000..dccfaf35a --- /dev/null +++ b/paddle/fluid/operators/cvm_op.cc @@ -0,0 +1,163 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/cvm_op.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class CVMOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("CVM"), "Input(CVM) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto cvm_dims = ctx->GetInputDim("CVM"); + PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(cvm_dims.size(), 2UL, "Input(CVM)'s rank should be 2."); + PADDLE_ENFORCE_EQ(cvm_dims[1], 2UL, + "The 2nd dimension of " + "Input(CVM) should be 2."); + + if (ctx->Attrs().Get("use_cvm")) { + ctx->SetOutputDim("Y", {x_dims[0], x_dims[1]}); + } else { + ctx->SetOutputDim("Y", {x_dims[0], x_dims[1] - 2}); + } + ctx->ShareLoD("X", /*->*/ "Y"); + } + + protected: + // Explicitly set that the data type of computation kernel of + // cvm + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class CVMGradientOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); + PADDLE_ENFORCE(ctx->HasInput("CVM"), "Input(CVM) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), + "Input(Y@GRAD) should be not null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@GRAD) should be not null."); + + auto x_dims = ctx->GetInputDim("X"); + auto cvm_dims = ctx->GetInputDim("CVM"); + auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y")); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2."); + PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2."); + PADDLE_ENFORCE_EQ(cvm_dims.size(), 2, "Input(CVM)'s rank should be 2."); + + PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0], + "The 1st dimension of Input(X) and Input(Y@Grad) should " + "be equal."); + + PADDLE_ENFORCE_EQ(cvm_dims[1], 2, + "When Attr(soft_label) == false, the 2nd dimension of " + "Input(CVM) should be 2."); + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + ctx->ShareLoD("X", framework::GradVarName("X")); + } + + protected: + // Explicitly set that the data type of computation kernel of + // cvm + // is determined by its input "X". + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class CVMOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(LodTensor, default LodTensor), a 2-D tensor with shape " + "[N x D]," + " where N is the batch size and D is the emebdding dim. "); + AddInput("CVM", + "(Tensor), a 2-D Tensor with shape [N x 2], where N is the batch " + "size, 2 is show and click."); + AddOutput("Y", + "(LodTensor, default LodTensor), a 2-D tensor with shape " + "[N x K]."); + AddAttr("use_cvm", "bool, use cvm or not").SetDefault(true); + AddComment(R"DOC( +CVM Operator. + + example: + input = fluid.layers.data(name=\"input\", shape=[-1, 1], lod_level=1, append_batch_size=False, dtype=\"int64\") + label = fluid.layers.data(name=\"label\", shape=[-1, 1], append_batch_size=False, dtype=\"int64\") + + embed = fluid.layers.embedding( + input=input, + size=[100, 11], + dtype='float32') + + ones = fluid.layers.fill_constant_batch_size_like(input=label, shape=[-1, 1], dtype=\"int64\", value=1) + show_clk = fluid.layers.cast(fluid.layers.concat([ones, label], axis=1), dtype='float32') + show_clk.stop_gradient = True + + input_with_cvm = fluid.layers.continuous_value_model(embed, show_clk, True) + +)DOC"); + } +}; +class CVMGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType("cvm_grad"); + op->SetInput("X", Input("X")); + op->SetInput("CVM", Input("CVM")); + op->SetInput(framework::GradVarName("Y"), OutputGrad("Y")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetOutput(framework::GradVarName("CVM"), InputGrad("CVM")); + op->SetAttrMap(Attrs()); + return op; + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(cvm, ops::CVMOp, ops::CVMOpMaker, ops::CVMGradOpDescMaker); + +REGISTER_OPERATOR(cvm_grad, ops::CVMGradientOp); + +REGISTER_OP_CPU_KERNEL(cvm, ops::CVMOpKernel, ops::CVMOpKernel); + +REGISTER_OP_CPU_KERNEL(cvm_grad, ops::CVMGradOpKernel, + ops::CVMGradOpKernel); diff --git a/paddle/fluid/operators/cvm_op.h b/paddle/fluid/operators/cvm_op.h new file mode 100644 index 000000000..ee1199254 --- /dev/null +++ b/paddle/fluid/operators/cvm_op.h @@ -0,0 +1,105 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class CVMOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const LoDTensor* x = context.Input("X"); + const T* x_data = x->data(); + auto lod = x->lod()[0]; + int64_t item_size = x->numel() / x->dims()[0]; + int offset = 2; + if (!context.Attr("use_cvm")) { + item_size -= offset; + } + LoDTensor* y = context.Output("Y"); + T* y_data = y->mutable_data(context.GetPlace()); + + int seq_num = static_cast(lod.size()) - 1; + for (int i = 0; i < seq_num; ++i) { + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + + for (int j = 0; j < seq_len; ++j) { + if (context.Attr("use_cvm")) { + std::memcpy(y_data, x_data, item_size * sizeof(T)); + y_data[0] = log(y_data[0] + 1); + y_data[1] = log(y_data[1] + 1) - y_data[0]; + x_data += item_size; + y_data += item_size; + } else { + std::memcpy(y_data, x_data + offset, item_size * sizeof(T)); + x_data += item_size + offset; + y_data += item_size; + } + } + } + } +}; + +template +class CVMGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + LoDTensor* dx = context.Output(framework::GradVarName("X")); + T* dx_data = dx->mutable_data(context.GetPlace()); + + const Tensor* cvm = context.Input("CVM"); + const T* cvm_data = cvm->data(); + int offset = 2; + const framework::LoDTensor* dOut = + context.Input(framework::GradVarName("Y")); + const T* dout_data = dOut->data(); + + auto lod = dx->lod()[0]; + int64_t item_size = dx->numel() / dx->dims()[0]; + if (!context.Attr("use_cvm")) { + item_size -= offset; + } + + int seq_num = static_cast(lod.size()) - 1; + for (int i = 0; i < seq_num; ++i) { + int64_t seq_len = static_cast(lod[i + 1] - lod[i]); + + for (int j = 0; j < seq_len; ++j) { + if (context.Attr("use_cvm")) { + std::memcpy(dx_data, dout_data, item_size * sizeof(T)); + dx_data[0] = cvm_data[0]; + dx_data[1] = cvm_data[1]; + dx_data += item_size; + dout_data += item_size; + } else { + std::memcpy(dx_data + offset, dout_data, item_size * sizeof(T)); + dx_data[0] = cvm_data[0]; + dx_data[1] = cvm_data[1]; + dx_data += item_size + offset; + dout_data += item_size; + } + } + cvm_data += offset; + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index a5d4d3947..b06a51d7a 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -193,6 +193,7 @@ __all__ = [ 'npair_loss', 'pixel_shuffle', 'fsp_matrix', + 'continuous_value_model', ] kIgnoreIndex = -100 @@ -11062,3 +11063,47 @@ def fsp_matrix(x, y): input_param_name='x')) helper.append_op(type='fsp', inputs={'X': x, 'Y': y}, outputs={'Out': out}) return out + + +def continuous_value_model(input, cvm, use_cvm=True): + """ + **continuous_value_model layers** + continuous value moded(cvm). now, it only consider show and click value in ctr project. + We assume that input is a embedding vector with cvm_feature, which shape is [N * D] (D is 2 + embedding dim) + if use_cvm is True, we will log(cvm_feature), and output shape is [N * D]. + if use_cvm is False, we will remove cvm_feature from inpput, and output shape is [N * (D - 2)]. + + This layer accepts a tensor named input which is ID after embedded and lod level is 1 , + cvm is a show_click info. + Args: + input (Variable): a 2-D LodTensor with shape [N x D], where N is the + batch size, D is 2 + the embedding dim. + lod level = 1. + cvm (Variable): a 2-D Tensor with shape [N x 2], where N is the batch size, 2 is show and click. + use_cvm (bool): use cvm or not. if use cvm, the output dim is the same as input + if don't use cvm, the output dim is input dim - 2(remove show and click). + (cvm op is a customized op, which input is a sequence had embedd_with_cvm default, so we need a op named cvm to decided whever use it or not.) + Returns: + Variable: A 2-D LodTensor with shape [N x D], if use cvm, D is equal to input dim, + if don't use cvm, D is equal to input dim - 2. + Examples: + .. code-block:: python + input = fluid.layers.data(name="input", shape=[-1, 1], lod_level=1, append_batch_size=False, dtype="int64")#, stop_gradient=False) + label = fluid.layers.data(name="label", shape=[-1, 1], append_batch_size=False, dtype="int64") + embed = fluid.layers.embedding( + input=input, + size=[100, 11], + dtype='float32') + ones = fluid.layers.fill_constant_batch_size_like(input=label, shape=[-1, 1], dtype="int64", value=1) + show_clk = fluid.layers.cast(fluid.layers.concat([ones, label], axis=1), dtype='float32') + show_clk.stop_gradient = True + input_with_cvm = fluid.layers.continuous_value_model(embed, show_clk, True) + """ + helper = LayerHelper('cvm', **locals()) + out = helper.create_variable(dtype=input.dtype) + helper.append_op( + type='cvm', + inputs={'X': [input], + 'CVM': [cvm]}, + outputs={'Y': [out]}, + attrs={"use_cvm": use_cvm}) -- GitLab