diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index a67625fa88fd2fbe4db43241ee824519ceac7017..3a6d1b6a291fdfe2ded193e43e6e15424285df3c 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -38,7 +38,7 @@ const Tensor* GetTensorFromVar(const Variable* var) { return &var->Get(); } PADDLE_ENFORCE(var->IsType(), - "The Input must be LoDTensor or Tensor."); + "The Input must be a LoDTensor or a Tensor."); return &var->Get(); } @@ -47,39 +47,39 @@ Tensor* GetTensorFromVar(Variable* var) { return var->GetMutable(); } PADDLE_ENFORCE(var->IsType(), - "The Input must be LoDTensor or Tensor."); + "The Input must be a LoDTensor or a Tensor."); return var->GetMutable(); } std::string OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); PADDLE_ENFORCE_LE(ins.size(), 1UL, - "Op %s input %s should contain only one variable", type_, - name); + "Operator %s's input %s should contain only one variable.", + type_, name); return ins.empty() ? kEmptyVarName : ins[0]; } const std::vector& OperatorBase::Inputs( const std::string& name) const { auto it = inputs_.find(name); - PADDLE_ENFORCE(it != inputs_.end(), "Op %s do not have input %s", type_, - name); + PADDLE_ENFORCE(it != inputs_.end(), "Operator %s does not have the input %s.", + type_, name); return it->second; } std::string OperatorBase::Output(const std::string& name) const { auto& outs = Outputs(name); PADDLE_ENFORCE_LE(outs.size(), 1UL, - "Op %s output %s should contain only one variable", type_, - name); + "Operator %s's output %s should contain only one variable.", + type_, name); return outs.empty() ? kEmptyVarName : outs[0]; } const std::vector& OperatorBase::Outputs( const std::string& name) const { auto it = outputs_.find(name); - PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output called %s", - type_, name); + PADDLE_ENFORCE(it != outputs_.end(), + "Operator %s does not have an output called %s.", type_, name); return it->second; } diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 4097f92e0216cf5a6f20d2fa7a07751b9f198eb6..d6ef0a80de069f58cc770415501bc5c36d311357 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -108,9 +108,10 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) { if (holder_ != nullptr) { holder_->set_type(type); } - PADDLE_ENFORCE_GT(numel(), 0, - "Tensor's numel must be larger than zero to call " - "Tensor::mutable_data. Call Tensor::set_dim first."); + PADDLE_ENFORCE_GT( + numel(), 0, + "When calling this method, the Tensor's numel must be larger than zero. " + "Please check Tensor::Resize has been called first."); int64_t size = numel() * SizeOfType(type); /* some versions of boost::variant don't have operator!= */ if (holder_ == nullptr || !(holder_->place() == place) || diff --git a/paddle/operators/linear_chain_crf_op.cc b/paddle/operators/linear_chain_crf_op.cc index 65bbfff0f8e7c4dbaf958518f92c7c48bfd6c964..06d71d26be0d05c7d318670731d22f8c83581bc1 100644 --- a/paddle/operators/linear_chain_crf_op.cc +++ b/paddle/operators/linear_chain_crf_op.cc @@ -204,8 +204,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(emission_exps_dims[0], "An empty mini-batch is not allowed."); - auto transition_exps_dims = - ctx->GetInputDim(framework::GradVarName("TransitionExps")); + auto transition_exps_dims = ctx->GetInputDim("TransitionExps"); PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2UL, "The Input(TransitionExps) should be a 2-D tensor."); PADDLE_ENFORCE_EQ( @@ -240,7 +239,8 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { // operator is determined by its input: graidents of LogLikelihood. framework::DataType IndicateDataType( const framework::ExecutionContext& ctx) const override { - return framework::ToDataType(ctx.Input("LogLikelihood")->type()); + return framework::ToDataType( + ctx.Input(framework::GradVarName("LogLikelihood"))->type()); } }; diff --git a/paddle/operators/linear_chain_crf_op.cu b/paddle/operators/linear_chain_crf_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..6fc8995f4c2ce05f89ffb58129695113f89159fa --- /dev/null +++ b/paddle/operators/linear_chain_crf_op.cu @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/linear_chain_crf_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + linear_chain_crf, + ops::LinearChainCRFOpKernel, + ops::LinearChainCRFOpKernel); +REGISTER_OP_GPU_KERNEL( + linear_chain_crf_grad, + ops::LinearChainCRFGradOpKernel, + ops::LinearChainCRFGradOpKernel); diff --git a/paddle/operators/linear_chain_crf_op.h b/paddle/operators/linear_chain_crf_op.h index f028b6554e181d32eef5aaaf6886ddbcf570c970..81b36dd95d7cb8b40af05f4fc0ddea09a6ec86bc 100644 --- a/paddle/operators/linear_chain_crf_op.h +++ b/paddle/operators/linear_chain_crf_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { @@ -47,36 +48,90 @@ template class LinearChainCRFOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* emission_weights = ctx.Input("Emission"); - auto* transition_weights = ctx.Input("Transition"); - auto* emission_exps = ctx.Output("EmissionExps"); - emission_exps->mutable_data(ctx.GetPlace()); - auto* transition_exps = ctx.Output("TransitionExps"); - transition_exps->mutable_data(ctx.GetPlace()); - auto* label = ctx.Input("Label"); - - auto in_lod = emission_weights->lod(); - PADDLE_ENFORCE(in_lod.size(), "Input(Emission) is not a sequence."); - // TODO(caoying) The checks related to LoD information should be // moved into InferShape once after the InferShape is refactored. - PADDLE_ENFORCE_EQ(emission_weights->NumLevels(), 1UL, + PADDLE_ENFORCE_EQ(ctx.Input("Emission")->NumLevels(), 1UL, "The Input(Emission) should be a sequence."); - PADDLE_ENFORCE_EQ(label->NumLevels(), 1UL, + PADDLE_ENFORCE_EQ(ctx.Input("Label")->NumLevels(), 1UL, "The Input(Label) should be a sequence."); + auto in_lod = ctx.Input("Label")->lod(); + PADDLE_ENFORCE(in_lod.size(), "Input(Label) must be a sequence."); const size_t level = 0; + const size_t seq_num = in_lod[level].size() - 1; + + // These local variables hold the inputs and outputs, garanteeing them on + // CPU memory, to provide a consistent reference. + // TODO(caoying) Fix this by moving all these local variables into the + // class's data members once we can profile the whole training process. + LoDTensor* emission_weights = nullptr; + LoDTensor emission_weight_tensor; + Tensor* transition_weights = nullptr; + Tensor transition_weight_tensor; + LoDTensor* label = nullptr; + LoDTensor label_tensor; + + Tensor* emission_exps = nullptr; + Tensor emission_exps_tensor; + Tensor* transition_exps = nullptr; + Tensor transition_exps_tensor; + Tensor* alpha = nullptr; + Tensor alpha_tensor; + Tensor* ll = nullptr; + Tensor ll_tensor; + + if (platform::is_gpu_place(ctx.GetPlace())) { + emission_weights = &emission_weight_tensor; + transition_weights = &transition_weight_tensor; + label = &label_tensor; + + CopyInputsToCpuMemory( + ctx.device_context(), *ctx.Input("Emission"), + *ctx.Input("Transition"), *ctx.Input("Label"), + emission_weights, transition_weights, label); + + emission_exps = &emission_exps_tensor; + emission_exps->Resize(emission_weights->dims()); + + transition_exps = &transition_exps_tensor; + transition_exps->Resize(transition_weights->dims()); + + alpha = &alpha_tensor; + alpha->Resize(ctx.Output("Alpha")->dims()); + + ll = &ll_tensor; + } else { + emission_weights = + const_cast(ctx.Input("Emission")); + transition_weights = const_cast(ctx.Input("Transition")); + label = const_cast(ctx.Input("Label")); + + emission_exps = ctx.Output("EmissionExps"); + transition_exps = ctx.Output("TransitionExps"); + alpha = ctx.Output("Alpha"); + ll = ctx.Output("LogLikelihood"); + } + // Because the computation codes only runs on CPU, here the memory for all + // the outputs is FIXED to be allocated on the CPU memory. + emission_exps->mutable_data(platform::CPUPlace()); + transition_exps->mutable_data(platform::CPUPlace()); + alpha->mutable_data(platform::CPUPlace()); + + // Resize the output tensor to its correct dimension. + ll->Resize({static_cast(seq_num), 1}); + ll->mutable_data(platform::CPUPlace()); + + // Now, all the inputs and outputs should be on the CPU memory. auto emission_dims = emission_weights->dims(); const size_t batch_size = emission_dims[0]; const size_t tag_num = emission_dims[1]; - const size_t seq_num = in_lod[level].size() - 1; Tensor emission_row_max; emission_row_max.mutable_data( framework::make_ddim({static_cast(batch_size), 1}), - ctx.GetPlace()); + platform::CPUPlace()); - auto place = ctx.GetEigenDevice(); + auto place = ctx.GetEigenDevice(); auto x = EigenMatrix::From(*emission_weights); auto x_row_max = EigenMatrix::From(emission_row_max); x_row_max.device(place) = @@ -91,12 +146,7 @@ class LinearChainCRFOpKernel : public framework::OpKernel { auto w_exps = EigenMatrix::From(*transition_exps); w_exps.device(place) = w.exp(); - auto* alpha = ctx.Output("Alpha"); - alpha->mutable_data(ctx.GetPlace()); - auto* ll = ctx.Output("LogLikelihood"); - // resize the output tensor to the correct dimension. - ll->Resize({static_cast(seq_num), 1}); - T* log_likelihood = ll->mutable_data(ctx.GetPlace()); + T* log_likelihood = ll->data(); for (size_t i = 0; i < seq_num; ++i) { int start_pos = static_cast(in_lod[level][i]); int end_pos = static_cast(in_lod[level][i + 1]); @@ -116,9 +166,61 @@ class LinearChainCRFOpKernel : public framework::OpKernel { one_seq, one_seq_row_max, one_seq_exps, *transition_weights, *transition_exps, one_seq_label, &one_seq_alpha); } + + if (platform::is_gpu_place(ctx.GetPlace())) { + CopyOutputsToGpuMemory( + ctx.device_context(), *emission_exps, *transition_exps, *alpha, *ll, + ctx.Output("EmissionExps"), + ctx.Output("TransitionExps"), ctx.Output("Alpha"), + ctx.Output("LogLikelihood")); + } + }; + + private: + void CopyInputsToCpuMemory(const platform::DeviceContext& ctx, + const LoDTensor& emission_weights_src, + const Tensor& transition_weights_src, + const LoDTensor& label_src, + LoDTensor* emission_weights_dst, + Tensor* transition_weights_dst, + LoDTensor* label_dst) const { + // Copy the inputs from GPU memory to CPU memory if this operators runs on + // GPU device. + auto copyLoDTensor = [](const platform::DeviceContext& ctx, + const LoDTensor& src, LoDTensor* dst) { + dst->mutable_data(src.dims(), platform::CPUPlace()); + dst->CopyFrom(src, platform::CPUPlace(), ctx); + + }; + copyLoDTensor(ctx, emission_weights_src, emission_weights_dst); + copyLoDTensor(ctx, label_src, label_dst); + + transition_weights_dst->mutable_data(transition_weights_src.dims(), + platform::CPUPlace()); + transition_weights_dst->CopyFrom(transition_weights_src, + platform::CPUPlace(), ctx); + } + + void CopyOutputsToGpuMemory(const platform::DeviceContext& ctx, + const Tensor& emission_exps_src, + const Tensor& transition_exps_src, + const Tensor& alpha_src, const Tensor& ll_src, + Tensor* emission_exps_dst, + Tensor* transition_exps_dst, Tensor* alpha_dst, + Tensor* ll_dst) const { + // Copy the forward results from CPU memory to GPU memory if this + // operators runs on GPU device. + auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor& src, + Tensor* dst) { + dst->mutable_data(platform::GPUPlace()); + dst->CopyFrom(src, platform::GPUPlace(), ctx); + }; + copyTensor(ctx, emission_exps_src, emission_exps_dst); + copyTensor(ctx, transition_exps_src, transition_exps_dst); + copyTensor(ctx, alpha_src, alpha_dst); + copyTensor(ctx, ll_src, ll_dst); }; - protected: T ForwardOneSequence(const Tensor& emission, const Tensor& emission_row_max, const Tensor& emission_exps, const Tensor& trans_weights, const Tensor& trans_weight_exps, const Tensor& label, @@ -183,35 +285,84 @@ template class LinearChainCRFGradOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* label = ctx.Input("Label"); - auto* emission_exps = ctx.Input("EmissionExps"); - auto* transition_exps = ctx.Input("TransitionExps"); - auto* alpha = ctx.Input("Alpha"); - const T* ll_grad = - ctx.Input(framework::GradVarName("LogLikelihood"))->data(); - - auto place = ctx.GetPlace(); - auto* emission_grad = - ctx.Output(framework::GradVarName("Emission")); - emission_grad->mutable_data(place); - - auto* trans_grad = ctx.Output(framework::GradVarName("Transition")); - if (trans_grad) { - trans_grad->mutable_data(place); + const size_t level = 0; // currently, only support sequence. + auto lod = ctx.Input("Label")->lod(); + PADDLE_ENFORCE(lod.size(), "Input(Label) must be a sequence."); + + // These local variables hold the inputs and outputs, garanteeing them on + // CPU memory, to provide a consistent reference. + // TODO(caoying) Fix this by moving all these local variables into the + // class's data members once we can profile the training process. + Tensor* label = nullptr; + Tensor label_tensor; + Tensor* emission_exps = nullptr; + Tensor emission_exps_tensor; + Tensor* transition_exps = nullptr; + Tensor transition_exps_tensor; + Tensor* alpha = nullptr; + Tensor alpha_tensor; + Tensor ll_grad_tensor; + T* ll_grad = nullptr; + + Tensor* emission_grad = nullptr; + Tensor emission_grad_tensor; + Tensor* transition_grad = nullptr; + Tensor transition_grad_tensor; + + if (platform::is_gpu_place(ctx.GetPlace())) { + label = &label_tensor; + emission_exps = &emission_exps_tensor; + transition_exps = &transition_exps_tensor; + alpha = &alpha_tensor; + CopyInputsToCpuMemory( + ctx.device_context(), *ctx.Input("Label"), + *ctx.Input("EmissionExps"), + *ctx.Input("TransitionExps"), *ctx.Input("Alpha"), + *ctx.Input(framework::GradVarName("LogLikelihood")), label, + emission_exps, transition_exps, alpha, &ll_grad_tensor); + ll_grad = ll_grad_tensor.data(); + + if (ctx.Output(framework::GradVarName("Emission"))) { + emission_grad = &emission_grad_tensor; + emission_grad->Resize(emission_exps->dims()); + } + + if (ctx.Output(framework::GradVarName("Transition"))) { + transition_grad = &transition_grad_tensor; + transition_grad->Resize(transition_exps->dims()); + } + } else { + label = const_cast(ctx.Input("Label")); + emission_exps = const_cast(ctx.Input("EmissionExps")); + transition_exps = + const_cast(ctx.Input("TransitionExps")); + alpha = const_cast(ctx.Input("Alpha")); + ll_grad = const_cast( + ctx.Input(framework::GradVarName("LogLikelihood"))) + ->data(); + + emission_grad = ctx.Output(framework::GradVarName("Emission")); + transition_grad = + ctx.Output(framework::GradVarName("Transition")); + } + PADDLE_ENFORCE(emission_grad, "Output(Emission@Grad) should not be null."); + emission_grad->mutable_data(platform::CPUPlace()); + math::SetConstant()(ctx.device_context(), + emission_grad, 0.); + if (transition_grad) { + transition_grad->mutable_data(platform::CPUPlace()); + math::SetConstant()(ctx.device_context(), + transition_grad, 0.); } + // Now, all the inputs and outputs should be on the CPU memory. auto emission_dims = emission_exps->dims(); - // Beta is the memo table used in dynamic programming to calculate the // backwark vectors. For a backward vector i (the i-th row of beta), it - // captures the unnormalized probabilities of partial sequences starting at - // position i. + // captures the unnormalized probabilities of partial sequences starting + // at position i. Tensor beta; - beta.mutable_data(emission_dims, place); - - const size_t level = 0; // currently, only support sequence. - auto lod = label->lod(); - PADDLE_ENFORCE(lod.size(), "Input(Label) is not a sequence."); + beta.mutable_data(emission_dims, platform::CPUPlace()); for (size_t i = 0; i < lod[level].size() - 1; ++i) { int start_pos = static_cast(lod[level][i]); @@ -228,11 +379,60 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel { BackwardOneSequence(ctx.device_context(), ll_grad[i], one_seq_emission_exps, *transition_exps, one_seq_alpha, one_seq_label, &one_seq_beta, - trans_grad, &one_seq_emission_grad); + transition_grad, &one_seq_emission_grad); + } + + if (platform::is_gpu_place(ctx.GetPlace())) { + CopyOutputsToGpuMemory( + ctx.device_context(), emission_grad, transition_grad, + ctx.Output(framework::GradVarName("Emission")), + ctx.Output(framework::GradVarName("Transition"))); } }; - protected: + private: + void CopyInputsToCpuMemory(const platform::DeviceContext& ctx, + const LoDTensor& label_src, + const Tensor& emission_exps_src, + const Tensor& transition_exps_src, + const Tensor& alpha_src, const Tensor& ll_grad_src, + Tensor* label_dst, Tensor* emission_exps_dst, + Tensor* transition_exps_dst, Tensor* alpha_dst, + Tensor* ll_grad_dst) const { + // Copy the inputs from GPU memory to CPU memory when this operators runs on + // GPU device. + label_dst->mutable_data(label_src.dims(), platform::CPUPlace()); + label_dst->CopyFrom(label_src, platform::CPUPlace(), ctx); + + auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor& src, + Tensor* dst) { + dst->mutable_data(src.dims(), platform::CPUPlace()); + dst->CopyFrom(src, platform::CPUPlace(), ctx); + }; + copyTensor(ctx, emission_exps_src, emission_exps_dst); + copyTensor(ctx, transition_exps_src, transition_exps_dst); + copyTensor(ctx, alpha_src, alpha_dst); + copyTensor(ctx, ll_grad_src, ll_grad_dst); + }; + + void CopyOutputsToGpuMemory(const platform::DeviceContext& ctx, + const Tensor* emission_grad_src, + const Tensor* transition_grad_src, + Tensor* emission_grad_dst, + Tensor* transition_grad_dst) const { + // Copy the backward results from CPU memory to GPU + // memory if this operators runs on GPU device. + auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor* src, + Tensor* dst) { + if (src && dst) { + dst->mutable_data(platform::GPUPlace()); + dst->CopyFrom(*src, platform::GPUPlace(), ctx); + } + }; + copyTensor(ctx, emission_grad_src, emission_grad_dst); + copyTensor(ctx, transition_grad_src, transition_grad_dst); + }; + void BackwardOneSequence(const platform::DeviceContext& ctx, const T ll_grad, const Tensor& emission_exps, const Tensor& transition_exps, const Tensor& alpha, @@ -255,7 +455,6 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel { beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i]; } NormalizeL1(beta_value + (seq_length - 1) * tag_num, tag_num); - for (int k = static_cast(seq_length) - 2; k >= 0; --k) { for (size_t i = 0; i < tag_num; ++i) { T sum = 0.; @@ -270,10 +469,11 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel { NormalizeL1(beta_value + k * tag_num, tag_num); } + auto x_grad_mat = EigenMatrix::From(*emission_grad); auto alpha_mat = EigenMatrix::From(alpha); auto beta_mat = EigenMatrix::From(*beta); - auto x_grad_mat = EigenMatrix::From(*emission_grad); - auto* place = ctx.GetEigenDevice(); + + auto* place = ctx.GetEigenDevice(); auto prob = alpha_mat * beta_mat; auto row_sum = prob.sum(Eigen::DSizes(1)) .reshape(Eigen::DSizes(seq_length, 1)) @@ -296,7 +496,7 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel { // TODO(caoying): Fix this to avoid using this local variable. Tensor tmp; - tmp.mutable_data(beta->dims(), ctx.GetPlace()); + tmp.mutable_data(beta->dims(), platform::CPUPlace()); auto tmp_mat = EigenMatrix::From(tmp); auto prob = beta_mat * x_exps_mat; auto row_sum = prob.sum(Eigen::DSizes(1))