From 877decdccc3475999abb72ae7383afc24c356cb8 Mon Sep 17 00:00:00 2001 From: xutianbing Date: Wed, 11 Jan 2017 13:31:36 -0800 Subject: [PATCH] merge Daoyuan's FuncArg, address one of the comments. --- paddle/function/CMakeLists.txt | 1 + paddle/function/CosSimOp.cpp | 161 +++++++++----------- paddle/function/CosSimOp.h | 18 +-- paddle/function/CosSimOpGpu.cu | 71 ++++----- paddle/function/CosSimOpTest.cpp | 61 ++++---- paddle/function/FunctionTest.h | 4 + paddle/gserver/layers/CosSimLayer.cpp | 31 ++-- paddle/gserver/layers/CosSimVecMatLayer.cpp | 73 ++++++--- paddle/gserver/layers/CosSimVecMatLayer.h | 54 ------- 9 files changed, 226 insertions(+), 248 deletions(-) delete mode 100644 paddle/gserver/layers/CosSimVecMatLayer.h diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt index fae3b7b20..1522510e8 100644 --- a/paddle/function/CMakeLists.txt +++ b/paddle/function/CMakeLists.txt @@ -27,6 +27,7 @@ if(WITH_TESTING) add_simple_unittest(ContextProjectionOpTest) add_simple_unittest(PadOpTest) add_simple_unittest(MulOpTest) + add_simple_unittest(CosSimOpTest) endif() endif() diff --git a/paddle/function/CosSimOp.cpp b/paddle/function/CosSimOp.cpp index 1c3cb6f18..130ee56f3 100644 --- a/paddle/function/CosSimOp.cpp +++ b/paddle/function/CosSimOp.cpp @@ -27,21 +27,21 @@ namespace paddle { * */ template <> -void CosSimForward(CpuMatrix* out_mat, - const CpuMatrix* in1_mat, - const CpuMatrix* in2_mat, +void CosSimForward(CpuMatrix& out_mat, + const CpuMatrix& in1_mat, + const CpuMatrix& in2_mat, real scale) { - CHECK(out_mat && in1_mat && in2_mat); - size_t num_samples = out_mat->getHeight(); - size_t dim = in1_mat->getWidth(); + CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData()); + size_t num_samples = out_mat.getHeight(); + size_t dim = in1_mat.getWidth(); /// column vector [nSamples, 1] - real* out = out_mat->getData(); - const real* x = in1_mat->getData(); - const real* y = in2_mat->getData(); + real* out = out_mat.getData(); + const real* x = in1_mat.getData(); + const real* y = in2_mat.getData(); /// in2 might only have one row or full rows - CHECK(in2_mat->getHeight() == 1LU || in2_mat->getHeight() == num_samples); - size_t inc = (in2_mat->getHeight() == 1LU) ? 0 : dim; + CHECK(in2_mat.getHeight() == 1LU || in2_mat.getHeight() == num_samples); + size_t inc = (in2_mat.getHeight() == 1LU) ? 0 : dim; for (size_t i = 0; i < num_samples; ++i, x += dim, y += inc) { real square_sum_x = 0; real square_sum_y = 0; @@ -75,26 +75,26 @@ class CosSimForwardFunc : public FunctionBase { scale_ = config.get("scale"); } - void calc(const Arguments& inputs, - const Arguments& outputs, - const Arguments& inouts) override { + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(inputs.size(), 2); CHECK_EQ(outputs.size(), 1); - CHECK_EQ(inouts.size(), 0); - CHECK_EQ(inputs[0].dims_[0], outputs[0].dims_[0]); - CHECK_EQ(inputs[0].dims_[1], inputs[1].dims_[1]); - CHECK_EQ(outputs[0].dims_[1], 1UL); + CHECK_EQ(inputs[0].shape().ndims(), (size_t)2); + CHECK_EQ(inputs[1].shape().ndims(), (size_t)2); + CHECK_EQ(outputs[0].shape().ndims(), (size_t)2); - CHECK(outputs[0].getData() && inputs[0].getData() && inputs[1].getData()); - auto out_mat = std::make_shared::type>( - outputs[0].getData(), outputs[0].dims_[0], outputs[0].dims_[1]); - const auto in1_mat = std::make_shared::type>( - inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]); - const auto in2_mat = std::make_shared::type>( - inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]); + CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]); + CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]); + CHECK_EQ(outputs[0].shape()[1], 1UL); - CosSimForward(out_mat.get(), in1_mat.get(), in2_mat.get(), scale_); + CHECK(outputs[0].data() && inputs[0].data() && inputs[1].data()); + + CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO); + auto out_mat = outputs[0].matrix(); + const auto in1_mat = inputs[0].matrix(); + const auto in2_mat = inputs[1].matrix(); + + CosSimForward(out_mat, in1_mat, in2_mat, scale_); } private: @@ -116,28 +116,29 @@ private: * \param scale, default 1.0 */ template <> -void CosSimBackward(const CpuMatrix* out_grad, - const CpuMatrix* out_val, - const CpuMatrix* in1_val, - const CpuMatrix* in2_val, - CpuMatrix* in1_grad, - CpuMatrix* in2_grad, +void CosSimBackward(const CpuMatrix& out_grad, + const CpuMatrix& out_val, + const CpuMatrix& in1_val, + const CpuMatrix& in2_val, + CpuMatrix& in1_grad, + CpuMatrix& in2_grad, real scale) { - CHECK(out_grad && out_val && in1_val && in2_val && in1_grad && in2_grad); - CHECK_EQ(out_val->useGpu_, false) << "Matrix type are GPU, CPU required"; - - const real* grad = out_grad->getData(); - const real* out = out_val->getData(); - const real* prev_out_x = in1_val->getData(); - const real* prev_out_y = in2_val->getData(); - real* prev_grad_x = in1_grad->getData(); - real* prev_grad_y = in2_grad->getData(); - - size_t num_samples = out_grad->getHeight(); - size_t dim = in1_val->getWidth(); - CHECK_EQ(in2_val->getHeight(), in2_grad->getHeight()); - CHECK(in2_val->getHeight() == 1LU || in2_val->getHeight() == num_samples); - size_t inc = (in2_val->getHeight() == 1LU) ? 0 : dim; + CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() && + in2_val.getData() && in1_grad.getData() && in2_grad.getData()); + CHECK_EQ(out_val.useGpu_, false) << "Matrix type are GPU, CPU required"; + + const real* grad = out_grad.getData(); + const real* out = out_val.getData(); + const real* prev_out_x = in1_val.getData(); + const real* prev_out_y = in2_val.getData(); + real* prev_grad_x = in1_grad.getData(); + real* prev_grad_y = in2_grad.getData(); + + size_t num_samples = out_grad.getHeight(); + size_t dim = in1_val.getWidth(); + CHECK_EQ(in2_val.getHeight(), in2_grad.getHeight()); + CHECK(in2_val.getHeight() == 1LU || in2_val.getHeight() == num_samples); + size_t inc = (in2_val.getHeight() == 1LU) ? 0 : dim; for (size_t i = 0; i < num_samples; ++i, prev_out_x += dim, prev_out_y += inc, @@ -178,8 +179,8 @@ void CosSimBackward(const CpuMatrix* out_grad, /** * Cosine Similarity backward Derivative * - * \param inouts[0] forward input grad 1, size: nSamples * dim. - * \param inouts[1] forward input grad 2, + * \param outputs[0] forward input grad 1, size: nSamples * dim. + * \param outputs[1] forward input grad 2, * size: n2 * dim (n2 == 1 or n2 == nSamples). * * \param inputs[0] backward loss output grad, size : nSamples * 1. @@ -194,46 +195,36 @@ class CosSimBackwardFunc : public FunctionBase { scale_ = config.get("scale"); } - void calc(const Arguments& inputs, - const Arguments& outputs, - const Arguments& inouts) override { + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { CHECK_EQ(inputs.size(), 4); - CHECK_EQ(outputs.size(), 0); - CHECK_EQ(inouts.size(), 2); + CHECK_EQ(outputs.size(), 2); /// dim of out_grad and out_val == 1, column vector - CHECK_EQ(inputs[0].dims_[1], 1UL); - CHECK_EQ(inputs[1].dims_[1], 1UL); + CHECK_EQ(inputs[0].shape()[1], 1UL); + CHECK_EQ(inputs[1].shape()[1], 1UL); /// nSamples of out_grad == out_val == in_val1 == in_grad1 - CHECK_EQ(inputs[1].dims_[0], inputs[0].dims_[0]); - CHECK_EQ(inputs[0].dims_[0], inputs[0].dims_[0]); - CHECK_EQ(inouts[0].dims_[0], inputs[0].dims_[0]); + CHECK_EQ(inputs[1].shape()[0], inputs[0].shape()[0]); + CHECK_EQ(inputs[0].shape()[0], inputs[0].shape()[0]); + CHECK_EQ(outputs[0].shape()[0], inputs[0].shape()[0]); /// dim of in1_val1 == in_val2 == in_grad1 == in_grad2 - CHECK_EQ(inputs[3].dims_[1], inputs[2].dims_[1]); - CHECK_EQ(inouts[0].dims_[1], inputs[2].dims_[1]); - CHECK_EQ(inouts[1].dims_[1], inputs[2].dims_[1]); - - CHECK(inputs[0].getData() && inputs[1].getData() && inputs[2].getData() && - inputs[3].getData() && inouts[0].getData() && inouts[1].getData()); - const auto out_grad = std::make_shared::type>( - inputs[0].getData(), inputs[0].dims_[0], inputs[0].dims_[1]); - const auto out_val = std::make_shared::type>( - inputs[1].getData(), inputs[1].dims_[0], inputs[1].dims_[1]); - const auto in1_val = std::make_shared::type>( - inputs[2].getData(), inputs[2].dims_[0], inputs[2].dims_[1]); - const auto in2_val = std::make_shared::type>( - inputs[3].getData(), inputs[3].dims_[0], inputs[3].dims_[1]); - auto in1_grad = std::make_shared::type>( - inouts[0].getData(), inouts[0].dims_[0], inouts[0].dims_[1]); - auto in2_grad = std::make_shared::type>( - inouts[1].getData(), inouts[1].dims_[0], inouts[1].dims_[1]); - - CosSimBackward(out_grad.get(), - out_val.get(), - in1_val.get(), - in2_val.get(), - in1_grad.get(), - in2_grad.get(), - scale_); + CHECK_EQ(inputs[3].shape()[1], inputs[2].shape()[1]); + CHECK_EQ(outputs[0].shape()[1], inputs[2].shape()[1]); + CHECK_EQ(outputs[1].shape()[1], inputs[2].shape()[1]); + + CHECK(inputs[0].data() && inputs[1].data() && inputs[2].data() && + inputs[3].data() && outputs[0].data() && outputs[1].data()); + + CHECK_EQ(outputs[0].getArgType(), ADD_TO); + CHECK_EQ(outputs[1].getArgType(), ADD_TO); + + const auto out_grad = inputs[0].matrix(); + const auto out_val = inputs[1].matrix(); + const auto in1_val = inputs[2].matrix(); + const auto in2_val = inputs[3].matrix(); + auto in1_grad = outputs[0].matrix(); + auto in2_grad = outputs[1].matrix(); + + CosSimBackward( + out_grad, out_val, in1_val, in2_val, in1_grad, in2_grad, scale_); } private: diff --git a/paddle/function/CosSimOp.h b/paddle/function/CosSimOp.h index ed1f1e4d5..be73064e6 100644 --- a/paddle/function/CosSimOp.h +++ b/paddle/function/CosSimOp.h @@ -32,9 +32,9 @@ namespace paddle { * */ template -void CosSimForward(typename MatrixT::type* output, - const typename MatrixT::type* input1, - const typename MatrixT::type* input2, +void CosSimForward(typename Tensor::Matrix& output, + const typename Tensor::Matrix& input1, + const typename Tensor::Matrix& input2, real scale); /** @@ -50,12 +50,12 @@ void CosSimForward(typename MatrixT::type* output, * */ template -void CosSimBackward(const typename MatrixT::type* out_grad, - const typename MatrixT::type* out_value, - const typename MatrixT::type* in1_value, - const typename MatrixT::type* in2_value, - typename MatrixT::type* in1_grad, - typename MatrixT::type* in2_grad, +void CosSimBackward(const typename Tensor::Matrix& out_grad, + const typename Tensor::Matrix& out_value, + const typename Tensor::Matrix& in1_value, + const typename Tensor::Matrix& in2_value, + typename Tensor::Matrix& in1_grad, + typename Tensor::Matrix& in2_grad, real scale); } // namespace paddle diff --git a/paddle/function/CosSimOpGpu.cu b/paddle/function/CosSimOpGpu.cu index f654f0bc2..1dd733674 100644 --- a/paddle/function/CosSimOpGpu.cu +++ b/paddle/function/CosSimOpGpu.cu @@ -65,12 +65,12 @@ __global__ void KeCosSim(real* output, } void hlCossim(real* output, - const real* input1, - const real* input2, - size_t width, - size_t input1_height, - size_t input2_height, - real scale) { + const real* input1, + const real* input2, + size_t width, + size_t input1_height, + size_t input2_height, + real scale) { CHECK_NOTNULL(output); CHECK_NOTNULL(input1); CHECK_NOTNULL(input2); @@ -84,20 +84,20 @@ void hlCossim(real* output, } template <> -void CosSimForward(GpuMatrix* out_mat, - const GpuMatrix* in1_mat, - const GpuMatrix* in2_mat, +void CosSimForward(GpuMatrix& out_mat, + const GpuMatrix& in1_mat, + const GpuMatrix& in2_mat, real scale) { - CHECK(out_mat && in1_mat && in2_mat); - CHECK(in1_mat->useGpu_ == true && in2_mat->useGpu_ == true) + CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData()); + CHECK(in1_mat.useGpu_ == true && in2_mat.useGpu_ == true) << "Matrix type are not GPU"; - size_t num_samples = out_mat->getHeight(); - size_t dim = in1_mat->getWidth(); - real* out = out_mat->getData(); - const real* x = in1_mat->getData(); - const real* y = in2_mat->getData(); - hlCossim(out, x, y, dim, in1_mat->getHeight(), in2_mat->getHeight(), scale); + size_t num_samples = out_mat.getHeight(); + size_t dim = in1_mat.getWidth(); + real* out = out_mat.getData(); + const real* x = in1_mat.getData(); + const real* y = in2_mat.getData(); + hlCossim(out, x, y, dim, in1_mat.getHeight(), in2_mat.getHeight(), scale); } template @@ -206,25 +206,26 @@ void hlCossimDerivative(const real* grad, } template <> -void CosSimBackward(const GpuMatrix* out_grad, - const GpuMatrix* out_val, - const GpuMatrix* in1_val, - const GpuMatrix* in2_val, - GpuMatrix* in1_grad, - GpuMatrix* in2_grad, +void CosSimBackward(const GpuMatrix& out_grad, + const GpuMatrix& out_val, + const GpuMatrix& in1_val, + const GpuMatrix& in2_val, + GpuMatrix& in1_grad, + GpuMatrix& in2_grad, real scale) { - CHECK(out_grad && out_val && in1_val && in2_val && in1_grad && in2_grad); - CHECK(out_grad->useGpu_ && out_val->useGpu_ && in1_val->useGpu_ - && in2_val->useGpu_ && in1_grad->useGpu_ && in2_grad->useGpu_) + CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() && + in2_val.getData() && in1_grad.getData() && in2_grad.getData()); + CHECK(out_grad.useGpu_ && out_val.useGpu_ && in1_val.useGpu_ + && in2_val.useGpu_ && in1_grad.useGpu_ && in2_grad.useGpu_) << "Matrix types are not equally GPU"; - size_t dim = in1_val->getWidth(); - const real* grad = out_grad->getData(); - const real* out = out_val->getData(); - const real* prev_out_x = in1_val->getData(); - const real* prev_out_y = in2_val->getData(); - real* prev_grad_x = in1_grad->getData(); - real* prev_grad_y = in2_grad->getData(); + size_t dim = in1_val.getWidth(); + const real* grad = out_grad.getData(); + const real* out = out_val.getData(); + const real* prev_out_x = in1_val.getData(); + const real* prev_out_y = in2_val.getData(); + real* prev_grad_x = in1_grad.getData(); + real* prev_grad_y = in2_grad.getData(); hlCossimDerivative(grad, out, prev_out_x, @@ -232,8 +233,8 @@ void CosSimBackward(const GpuMatrix* out_grad, prev_grad_x, prev_grad_y, dim, - in1_val->getHeight(), - in2_val->getHeight(), + in1_val.getHeight(), + in2_val.getHeight(), scale); } diff --git a/paddle/function/CosSimOpTest.cpp b/paddle/function/CosSimOpTest.cpp index f0e81ee04..dce959e81 100644 --- a/paddle/function/CosSimOpTest.cpp +++ b/paddle/function/CosSimOpTest.cpp @@ -36,16 +36,20 @@ void testCosSimForward(size_t height_x, CpuMatrix cpu_out(height_x, 1); GpuMatrix gpu_out(height_x, 1); - compare.getCpuFunction()->calc( - {Tensor(cpu_arg1.getData(), Dims{height_x, width}), - Tensor(cpu_arg2.getData(), Dims{height_y, width})}, - {Tensor(cpu_out.getData(), Dims{height_x, 1})}, - {}); - compare.getGpuFunction()->calc( - {Tensor(gpu_arg1.getData(), Dims{height_x, width}), - Tensor(gpu_arg2.getData(), Dims{height_y, width})}, - {Tensor(gpu_out.getData(), Dims{height_x, 1})}, - {}); + BufferArgs cpu_inputs; + BufferArgs cpu_outputs; + cpu_inputs.addArg(cpu_arg1); + cpu_inputs.addArg(cpu_arg2); + cpu_outputs.addArg(cpu_out, ASSIGN_TO); + + BufferArgs gpu_inputs; + BufferArgs gpu_outputs; + gpu_inputs.addArg(gpu_arg1); + gpu_inputs.addArg(gpu_arg2); + gpu_outputs.addArg(gpu_out, ASSIGN_TO); + + compare.getCpuFunction()->calc(cpu_inputs, cpu_outputs); + compare.getGpuFunction()->calc(gpu_inputs, gpu_outputs); autotest::TensorCheckErr(cpu_out, gpu_out); } @@ -96,23 +100,26 @@ void testCosSimBackward(size_t height_x, gpu_in1_grad.copyFrom(cpu_in1_grad); gpu_in2_grad.copyFrom(cpu_in2_grad); - compare.getCpuFunction()->calc( - {Tensor(cpu_out_grad.getData(), Dims{height_x, 1}), - Tensor(cpu_out_val.getData(), Dims{height_x, 1}), - Tensor(cpu_in1_val.getData(), Dims{height_x, width}), - Tensor(cpu_in2_val.getData(), Dims{height_x, width})}, - {}, - {Tensor(cpu_in1_grad.getData(), Dims{height_x, width}), - Tensor(cpu_in2_grad.getData(), Dims{height_x, width})}); - - compare.getGpuFunction()->calc( - {Tensor(gpu_out_grad.getData(), Dims{height_x, 1}), - Tensor(gpu_out_val.getData(), Dims{height_x, 1}), - Tensor(gpu_in1_val.getData(), Dims{height_x, width}), - Tensor(gpu_in2_val.getData(), Dims{height_x, width})}, - {}, - {Tensor(gpu_in1_grad.getData(), Dims{height_x, width}), - Tensor(gpu_in2_grad.getData(), Dims{height_x, width})}); + BufferArgs cpu_inputs; + BufferArgs cpu_outputs; + cpu_inputs.addArg(cpu_out_grad); + cpu_inputs.addArg(cpu_out_val); + cpu_inputs.addArg(cpu_in1_val); + cpu_inputs.addArg(cpu_in2_val); + cpu_outputs.addArg(cpu_in1_grad, ADD_TO); + cpu_outputs.addArg(cpu_in2_grad, ADD_TO); + + BufferArgs gpu_inputs; + BufferArgs gpu_outputs; + gpu_inputs.addArg(gpu_out_grad); + gpu_inputs.addArg(gpu_out_val); + gpu_inputs.addArg(gpu_in1_val); + gpu_inputs.addArg(gpu_in2_val); + gpu_outputs.addArg(gpu_in1_grad, ADD_TO); + gpu_outputs.addArg(gpu_in2_grad, ADD_TO); + + compare.getCpuFunction()->calc(cpu_inputs, cpu_outputs); + compare.getGpuFunction()->calc(gpu_inputs, gpu_outputs); autotest::TensorCheckErr(cpu_in1_grad, gpu_in1_grad); autotest::TensorCheckErr(cpu_in2_grad, gpu_in2_grad); diff --git a/paddle/function/FunctionTest.h b/paddle/function/FunctionTest.h index 0cfafdb27..35de3a65d 100644 --- a/paddle/function/FunctionTest.h +++ b/paddle/function/FunctionTest.h @@ -157,6 +157,9 @@ public: cpuSparse_->randomizeUniform(); gpuSparse_->copyFrom(*cpuSparse_, stream); hl_stream_synchronize(stream); + void addInputs(const SequenceArg& input) { + size_t batchSize = input.shape()[0]; + size_t numSeqs = batchSize / 10 + 1; cpuOutputs_.emplace_back( std::make_shared(*cpuSparse_, argType)); @@ -331,6 +334,7 @@ protected: } protected: +<<<<<<< HEAD std::shared_ptr cpuFunc_; std::shared_ptr gpuFunc_; std::vector cpuMemory_; diff --git a/paddle/gserver/layers/CosSimLayer.cpp b/paddle/gserver/layers/CosSimLayer.cpp index b00eda2f6..a6c0300ac 100644 --- a/paddle/gserver/layers/CosSimLayer.cpp +++ b/paddle/gserver/layers/CosSimLayer.cpp @@ -56,13 +56,12 @@ void CosSimLayer::forward(PassType passType) { MatrixPtr prevOut2 = getInputValue(1); CHECK(outV && prevOut1 && prevOut2); - forward_[0]->calc( - {Tensor(prevOut1->getData(), - Dims{prevOut1->getHeight(), prevOut1->getWidth()}), - Tensor(prevOut2->getData(), - Dims{prevOut2->getHeight(), prevOut2->getWidth()})}, - {Tensor(outV->getData(), Dims{outV->getHeight(), outV->getWidth()})}, - {}); + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*prevOut1); + inputs.addArg(*prevOut2); + outputs.addArg(*outV, ASSIGN_TO); + forward_[0]->calc(inputs, outputs); } } @@ -78,14 +77,16 @@ void CosSimLayer::backward(const UpdateCallback& callback) { auto inG1 = this->getInputGrad(0); auto inG2 = this->getInputGrad(1); CHECK(outG && outV && inV1 && inV2 && inG1 && inG2); - backward_[0]->calc( - {Tensor(outG->getData(), Dims{outG->getHeight(), outG->getWidth()}), - Tensor(outV->getData(), Dims{outV->getHeight(), outV->getWidth()}), - Tensor(inV1->getData(), Dims{inV1->getHeight(), inV1->getWidth()}), - Tensor(inV2->getData(), Dims{inV2->getHeight(), inV2->getWidth()})}, - {}, - {Tensor(inG1->getData(), Dims{inG1->getHeight(), inG1->getWidth()}), - Tensor(inG2->getData(), Dims{inG2->getHeight(), inG2->getWidth()})}); + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*outG); + inputs.addArg(*outV); + inputs.addArg(*inV1); + inputs.addArg(*inV2); + outputs.addArg(*inG1, ADD_TO); + outputs.addArg(*inG2, ADD_TO); + + backward_[0]->calc(inputs, outputs); } } diff --git a/paddle/gserver/layers/CosSimVecMatLayer.cpp b/paddle/gserver/layers/CosSimVecMatLayer.cpp index 120c4e84c..29ebe1ca9 100644 --- a/paddle/gserver/layers/CosSimVecMatLayer.cpp +++ b/paddle/gserver/layers/CosSimVecMatLayer.cpp @@ -12,11 +12,44 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "CosSimVecMatLayer.h" +#include "Layer.h" +#include "paddle/math/Matrix.h" #include "paddle/utils/Logging.h" #include "paddle/utils/Stat.h" namespace paddle { +/** + * @brief A layer for computing cosine similarity between a vector + * and each row of a matrix + * out[i] = cos_scale * cos(in1, in2(i,:)); + * @note used in NEURAL TURING MACHINE + * + * Input1: a vector (batchSize * dataDim) + * + * Input2: a matrix in vector form (batchSize * (weightDim*dataDim)) + * + * Output: a vector (batchSize * weightDim) + */ + +class CosSimVecMatLayer : public Layer { +public: + explicit CosSimVecMatLayer(const LayerConfig& config) : Layer(config) {} + + ~CosSimVecMatLayer() {} + + bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + + void forward(PassType passType); + void backward(const UpdateCallback& callback = nullptr); + +protected: + MatrixPtr tmpMtx0; + MatrixPtr tmpMtx1; + MatrixPtr tmpRow0; + MatrixPtr tmpRow1; + MatrixPtr tmpRow2; + MatrixPtr tmpRow3; +}; /** * @brief A layer for computing cosine similarity between a vector @@ -98,7 +131,6 @@ bool CosSimVecMatLayer::init(const LayerMap& layerMap, /* trans= */ false, useGpu_); - /// todo(tianbing), do we really need to check these shared pointers? CHECK(tmpRow0 && tmpRow1 && tmpRow2 && tmpRow3 && tmpMtx0 && tmpMtx1); createFunction(forward_, @@ -136,13 +168,12 @@ void CosSimVecMatLayer::forward(PassType passType) { tmpMtx0->setData(inV1->rowBuf(i)); tmpRow2->setData(outV->rowBuf(i)); - forward_[0]->calc({Tensor(tmpMtx0->getData(), - Dims{tmpMtx0->getHeight(), tmpMtx0->getWidth()}), - Tensor(tmpRow0->getData(), - Dims{tmpRow0->getHeight(), tmpRow0->getWidth()})}, - {Tensor(tmpRow2->getData(), - Dims{tmpRow2->getHeight(), tmpRow2->getWidth()})}, - {}); + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*tmpMtx0); + inputs.addArg(*tmpRow0); + outputs.addArg(*tmpRow2, ASSIGN_TO); + forward_[0]->calc(inputs, outputs); } } @@ -168,20 +199,16 @@ void CosSimVecMatLayer::backward(const UpdateCallback& callback) { tmpRow2->setData(outV->rowBuf(i)); tmpRow3->setData(outG->rowBuf(i)); - backward_[0]->calc( - {Tensor(tmpRow3->getData(), - Dims{tmpRow3->getHeight(), tmpRow3->getWidth()}), - Tensor(tmpRow2->getData(), - Dims{tmpRow2->getHeight(), tmpRow2->getWidth()}), - Tensor(tmpMtx0->getData(), - Dims{tmpMtx0->getHeight(), tmpMtx0->getWidth()}), - Tensor(tmpRow0->getData(), - Dims{tmpRow0->getHeight(), tmpRow0->getWidth()})}, - {}, - {Tensor(tmpMtx1->getData(), - Dims{tmpMtx1->getHeight(), tmpMtx1->getWidth()}), - Tensor(tmpRow1->getData(), - Dims{tmpRow1->getHeight(), tmpRow1->getWidth()})}); + BufferArgs inputs; + BufferArgs outputs; + inputs.addArg(*tmpRow3); + inputs.addArg(*tmpRow2); + inputs.addArg(*tmpMtx0); + inputs.addArg(*tmpRow0); + outputs.addArg(*tmpMtx1, ADD_TO); + outputs.addArg(*tmpRow1, ADD_TO); + + backward_[0]->calc(inputs, outputs); } } diff --git a/paddle/gserver/layers/CosSimVecMatLayer.h b/paddle/gserver/layers/CosSimVecMatLayer.h deleted file mode 100644 index df4e11848..000000000 --- a/paddle/gserver/layers/CosSimVecMatLayer.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "Layer.h" -#include "paddle/math/Matrix.h" - -namespace paddle { -/** - * @brief A layer for computing cosine similarity between a vector - * and each row of a matrix - * out[i] = cos_scale * cos(in1, in2(i,:)); - * @note used in NEURAL TURING MACHINE - * - * Input1: a vector (batchSize * dataDim) - * - * Input2: a matrix in vector form (batchSize * (weightDim*dataDim)) - * - * Output: a vector (batchSize * weightDim) - */ - -class CosSimVecMatLayer : public Layer { -public: - explicit CosSimVecMatLayer(const LayerConfig& config) : Layer(config) {} - - ~CosSimVecMatLayer() {} - - bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); - - void forward(PassType passType); - void backward(const UpdateCallback& callback = nullptr); - -protected: - MatrixPtr tmpMtx0; - MatrixPtr tmpMtx1; - MatrixPtr tmpRow0; - MatrixPtr tmpRow1; - MatrixPtr tmpRow2; - MatrixPtr tmpRow3; -}; - -} // namespace paddle -- GitLab