From 7f5e5ce9fcbc460595c26d07853a008d8ba2cef1 Mon Sep 17 00:00:00 2001 From: phlrain Date: Wed, 26 May 2021 06:16:48 +0000 Subject: [PATCH] first test version --- paddle/fluid/framework/CMakeLists.txt | 3 + paddle/fluid/framework/operator.cc | 3 + paddle/fluid/framework/operator.h | 427 ++++++++++++++++++ paddle/fluid/imperative/CMakeLists.txt | 2 + paddle/fluid/operators/fill_constant_op.cc | 7 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 9 +- .../operators/reduce_ops/reduce_sum_op.cc | 2 + .../softmax_with_cross_entropy_op.cc | 5 +- .../operators/softmax_with_cross_entropy_op.h | 14 +- 9 files changed, 463 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 4644e674ba4..2e9c81caf53 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -383,6 +383,9 @@ cc_library(op_meta_info SRCS ../extension/src/ext_op_meta_info.cc DEPS custom_te cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper custom_tensor op_meta_info) cc_test(custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog) +#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ) +cc_binary(new_executor SRCS new_exec.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) + set(FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator) cc_library(paddle_framework DEPS ${FLUID_FRAMEWORK_MODULES}) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 955c917b2c1..e2ced316ae9 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1026,6 +1026,7 @@ class RuntimeInferShapeContext : public InferShapeContext { const RuntimeContext& ctx_; }; + static void CheckTensorNANOrInf(const std::string& op_type, const std::string& name, const framework::Tensor& tensor) { @@ -1598,7 +1599,9 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType( proto::VarType::Type dafault_data_type = static_cast(-1); proto::VarType::Type data_type = dafault_data_type; + //std::cerr << "par in" << std::endl; ParseInputDataType(ctx, name, &data_type); + // PADDLE_ENFORCE_NE( data_type, dafault_data_type, platform::errors::InvalidArgument( diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 3fc61581eca..2128626a05c 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -37,6 +37,7 @@ limitations under the License. */ #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/variant.h" +#include "paddle/fluid/framework/operator.h" namespace paddle { namespace framework { @@ -572,5 +573,431 @@ class OperatorWithKernel : public OperatorBase { extern bool OpSupportGPU(const std::string& op_type); +/* +class RuntimeInferShapeContext : public InferShapeContext { + public: + RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx) + : op_(op), ctx_(ctx) {} + + bool HasInput(const std::string& name) const override { + // has only one input + const auto& ins = ctx_.inputs; + auto it = ins.find(name); + if (it == ins.end()) { + return false; + } + const auto& in = it->second; + if (in.size() == 0) return false; + PADDLE_ENFORCE_EQ( + in.size(), 1UL, + platform::errors::InvalidArgument( + "Input %s should not contain more than one inputs.", name)); + return in[0] != nullptr; + } + + bool HasOutput(const std::string& name) const override { + // has only one output + const auto& outs = ctx_.outputs; + auto it = outs.find(name); + if (it == outs.end()) { + return false; + } + const auto& out = it->second; + if (out.size() == 0) { + return false; + } + PADDLE_ENFORCE_EQ( + out.size(), 1UL, + platform::errors::InvalidArgument( + "Output %s should not contain more than one outputs.", name)); + return out[0] != nullptr; + } + + bool HasInputs(const std::string& name) const override { + const auto& ins = ctx_.inputs; + auto it = ins.find(name); + if (it == ins.end() || it->second.empty()) { + return false; + } + for (auto& input : it->second) { + if (input == nullptr) { + return false; + } + } + return true; + } + + bool HasOutputs(const std::string& name) const override { + const auto& outs = ctx_.outputs; + auto it = outs.find(name); + if (it == outs.end() || it->second.empty()) { + return false; + } + for (auto& output : it->second) { + if (output == nullptr) { + return false; + } + } + return true; + } + + AttrReader Attrs() const override { return AttrReader(op_.Attrs()); } + + std::vector Inputs(const std::string& name) const override { + return op_.Inputs(name); + } + + std::vector Outputs(const std::string& name) const override { + return op_.Outputs(name); + } + + std::string GetInputNameByIdx(size_t idx) const override { + auto& op_proto = + paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; + PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(), + platform::errors::OutOfRange( + "The index should be less than the size of inputs of " + "operator %s, but got index is %d and size is %d", + op_.Type(), idx, op_proto->inputs().size())); + return op_proto->inputs()[idx].name(); + } + + std::string GetOutputNameByIdx(size_t idx) const override { + auto& op_proto = + paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_; + PADDLE_ENFORCE_LT( + idx, op_proto->outputs().size(), + platform::errors::OutOfRange( + "The index should be less than the size of outputs of " + "operator %s, but got index is %d and size is %d", + op_.Type(), idx, op_proto->outputs().size())); + return op_proto->outputs()[idx].name(); + } + + void ShareDim(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) override { + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE_NE( + in_it, ctx_.inputs.end(), + platform::errors::NotFound("Input %s does not exist.", in)); + PADDLE_ENFORCE_NE( + out_it, ctx_.outputs.end(), + platform::errors::NotFound("Output %s does not exist.", out)); + PADDLE_ENFORCE_LT(i, in_it->second.size(), + platform::errors::InvalidArgument( + "The index of input dimension is out of range, " + "excepted index less than %zu, but received %zu.", + in_it->second.size(), i)); + PADDLE_ENFORCE_LT(j, out_it->second.size(), + platform::errors::InvalidArgument( + "The index of output dimension is out of range, " + "excepted index less than %zu, but received %zu.", + out_it->second.size(), j)); + + Variable* in_var = in_it->second[i]; + Variable* out_var = out_it->second[j]; + + PADDLE_ENFORCE_EQ( + in_var->Type(), out_var->Type(), + platform::errors::InvalidArgument( + "The type of input (%s) and output (%s) are inconsistent.", in, + out)); + + if (in_var->IsType()) { + auto& in_sele_rows = in_var->Get(); + auto out_sele_rows = out_var->GetMutable(); + out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims()); + out_sele_rows->set_rows(in_sele_rows.rows()); + out_sele_rows->set_height(in_sele_rows.height()); + } else if (in_var->IsType()) { + auto& in_lod_tensor = in_var->Get(); + auto* out_lod_tensor = out_var->GetMutable(); + out_lod_tensor->Resize(in_lod_tensor.dims()); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Currently, the input type of ShareDim only can be LoDTensor " + "or SelectedRows.")); + } + } + + void ShareAllLoD(const std::string& in, + const std::string& out) const override { + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE_NE(in_it, ctx_.inputs.end(), + platform::errors::NotFound( + "Input [%s] found error in Op [%s]", in, op_.Type())); + PADDLE_ENFORCE_NE( + out_it, ctx_.outputs.end(), + platform::errors::NotFound("Output [%s] found error in Op [%s]", out, + op_.Type())); + + auto& in_var_list = in_it->second; + auto& out_var_list = out_it->second; + + PADDLE_ENFORCE_EQ( + in_var_list.size(), out_var_list.size(), + platform::errors::PreconditionNotMet( + "Op [%s]: Input var size should be equal with output var size", + op_.Type())); + + auto& out_var_names = op_.Outputs(out); + + for (size_t i = 0; i < in_var_list.size(); ++i) { + if (out_var_names[i] == framework::kEmptyVarName) { + continue; + } + + Variable* in_var = in_var_list[i]; + if (!in_var->IsType()) return; + Variable* out_var = out_var_list[i]; + PADDLE_ENFORCE_EQ(out_var->IsType(), true, + platform::errors::PreconditionNotMet( + "The %d-th output of Output(%s) must be LoDTensor.", + i, out_var_names[i])); + auto& in_tensor = in_var->Get(); + auto* out_tensor = out_var->GetMutable(); + out_tensor->set_lod(in_tensor.lod()); +#ifdef PADDLE_WITH_MKLDNN + if (in_tensor.layout() != DataLayout::kMKLDNN) +#endif + out_tensor->set_layout(in_tensor.layout()); + } + } + + void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, + size_t j = 0) const override { + auto in_it = ctx_.inputs.find(in); + auto out_it = ctx_.outputs.find(out); + PADDLE_ENFORCE_NE( + in_it, ctx_.inputs.end(), + platform::errors::NotFound("Input %s does not exist.", in)); + PADDLE_ENFORCE_NE( + out_it, ctx_.outputs.end(), + platform::errors::NotFound("Output %s does not exist.", out)); + PADDLE_ENFORCE_LT(i, in_it->second.size(), + platform::errors::InvalidArgument( + "The index of input dimension is out of range, " + "excepted index less than %zu, but received %zu.", + in_it->second.size(), i)); + PADDLE_ENFORCE_LT(j, out_it->second.size(), + platform::errors::InvalidArgument( + "The index of output dimension is out of range, " + "excepted index less than %zu, but received %zu.", + out_it->second.size(), j)); + + Variable* in_var = in_it->second.at(i); + if (!in_var->IsType()) return; + Variable* out_var = out_it->second.at(j); + PADDLE_ENFORCE_EQ( + out_var->IsType(), true, + platform::errors::InvalidArgument( + "The %zu-th output of Output(%s) must be LoDTensor.", j, out)); + auto& in_tensor = in_var->Get(); + auto* out_tensor = out_var->GetMutable(); + out_tensor->set_lod(in_tensor.lod()); + +// TODO(dzhwinter) : reuse ShareLoD in most operators. +// Need to call ShareLayout explicitly in sequence related ops. +// Shall we have a better method to shared info between in/out Tensor? +#ifdef PADDLE_WITH_MKLDNN + // Fix me: ugly workaround below + // Correct solution: + // set_layout() should NOT be called here (i.e. ShareLoD). Instead, + // layout of output tensor should be set "manually" in Compute() + // of each OPKernel. The reason layout should NOT be shared between + // input and output "automatically" (now by InferShape()->ShareLoD()) + // is that layout transform may occur after InferShape(). + // Workaround: + // Skip set_layout() when input layout is kMKLDNN + // This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN + // OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called + // in Compute() + if (in_tensor.layout() != DataLayout::kMKLDNN) +#endif + out_tensor->set_layout(in_tensor.layout()); + } + + int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "GetLoDLevel is only used in compile time. The calculation of " + "output's actual lod is different among operators so that should be " + "set in the runtime kernel.")); + } + + void SetLoDLevel(const std::string& out, int32_t lod_level, + size_t j = 0) const override { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "SetLoDLevel is only used in compile time. The calculation of " + "output's actual lod is different among operators so that should be " + "set in the runtime kernel.")); + } + + bool IsRuntime() const override { return true; } + + // TODO(paddle-dev): Can this be template? + std::vector GetInputVarPtrs( + const std::string& name) override { + const std::vector& vars = InputVars(name); + std::vector res; + res.reserve(vars.size()); + res.insert(res.begin(), vars.begin(), vars.end()); + return res; + } + + std::vector GetOutputVarPtrs( + const std::string& name) override { + const std::vector& vars = OutputVars(name); + std::vector res; + res.reserve(vars.size()); + res.insert(res.begin(), vars.begin(), vars.end()); + return res; + } + + DDim GetInputDim(const std::string& name) const override { + const std::vector& vars = InputVars(name); + PADDLE_ENFORCE_EQ( + vars.size(), 1UL, + platform::errors::InvalidArgument( + "Input(%s) should hold one element, but now it holds %zu elements.", + name, vars.size())); + return this->GetDim(vars[0]); + } + + std::vector GetInputsDim(const std::string& name) const override { + const std::vector& vars = InputVars(name); + return GetDims(vars); + } + + std::vector GetInputsVarType( + const std::string& name) const override { + return GetVarTypes(InputVars(name)); + } + + std::vector GetOutputsVarType( + const std::string& name) const override { + return GetVarTypes(OutputVars(name)); + } + + void SetOutputDim(const std::string& name, const DDim& dim) override { + auto& vars = OutputVars(name); + PADDLE_ENFORCE_EQ( + vars.size(), 1UL, + platform::errors::InvalidArgument("Output(%s) should hold one element, " + "but now it holds %zu elements.", + name, vars.size())); + SetDim(vars[0], dim); + } + + void SetOutputsDim(const std::string& name, + const std::vector& dims) override { + auto& vars = OutputVars(name); + SetDims(vars, dims); + } + + protected: + DDim GetDim(Variable* var) const { + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::InvalidArgument("Input variable is nullptr.")); + if (var->IsType()) { + return var->Get().dims(); + } else if (var->IsType()) { + return var->Get().GetCompleteDims(); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Only LoDTensor or SelectedRows support 'GetDim', but input " + "Variable's type is %s.", + ToTypeName(var->Type()))); + } + } + + std::vector GetDims(const std::vector& vars) const { + std::vector ret; + ret.reserve(vars.size()); + std::transform(vars.begin(), vars.end(), std::back_inserter(ret), + [this](Variable* var) { return this->GetDim(var); }); + return ret; + } + + std::vector GetRepeatedDims(const std::string& name) const override { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "GetRepeatedDims method only ban be used in compile time.")); + } + + void SetDim(Variable* var, const DDim& dim) { + if (var->IsType()) { + var->GetMutable()->Resize(dim); + } else if (var->IsType()) { + var->GetMutable()->set_height(dim[0]); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Variable type error, expect LoDTensor or SelectedRows, but received " + "(%s).", + ToTypeName(var->Type()))); + } + } + + void SetDims(const std::vector& vars, + const std::vector& dims) { + size_t length = vars.size(); + PADDLE_ENFORCE_EQ(length, dims.size(), + platform::errors::InvalidArgument( + "The number of input variables do not match the " + "number of input dimensions, the number of variables " + "is %zu, the number of dimensions is %zu.", + length, dims.size())); + for (size_t i = 0; i < length; ++i) { + if (vars[i] == nullptr) { + continue; + } + SetDim(vars[i], dims[i]); + } + } + + void SetRepeatedDims(const std::string& name, + const std::vector& dims) override { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "SetRepeatedDims method only can be used in compile time.")); + } + + std::vector GetVarTypes( + const std::vector& vars) const { + std::vector retv; + retv.resize(vars.size()); + std::transform(vars.begin(), vars.end(), retv.begin(), + std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType), + this, std::placeholders::_1)); + return retv; + } + + proto::VarType::Type GetVarType(Variable* var) const { + return ToVarType(var->Type()); + } + + private: + const std::vector& InputVars(const std::string& name) const { + auto it = ctx_.inputs.find(name); + PADDLE_ENFORCE_NE( + it, ctx_.inputs.end(), + platform::errors::NotFound( + "Operator (%s) does not have the input (%s).", op_.Type(), name)); + return it->second; + } + + const std::vector& OutputVars(const std::string& name) const { + auto it = ctx_.outputs.find(name); + PADDLE_ENFORCE_NE( + it, ctx_.outputs.end(), + platform::errors::NotFound( + "Operator (%s) does not have the outputs (%s).", op_.Type(), name)); + return it->second; + } + + const OperatorBase& op_; + const RuntimeContext& ctx_; +}; +*/ + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 6bee3d44b2e..02ee315fe14 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -28,4 +28,6 @@ endif(NOT WIN32) cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function) +cc_binary(tracer_test SRCS tracer_test.cc DEPS tracer layer op_registry python pybind ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler ) + add_subdirectory(tests) diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index f35d8b6bbf8..4e576730edc 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fill_constant_op.h" #include #include "paddle/fluid/framework/op_version_registry.h" +#include namespace paddle { namespace operators { @@ -23,9 +24,11 @@ class FillConstantOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "FillConstant"); auto& shape = ctx->Attrs().Get>("shape"); + if (!ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList")) { for (size_t i = 0; i < shape.size(); ++i) { PADDLE_ENFORCE_GE( @@ -36,7 +39,7 @@ class FillConstantOp : public framework::OperatorWithKernel { i, shape[i], framework::make_ddim(shape))); } } - + if (shape.empty() && ctx->HasInput("ShapeTensor")) { auto shape_dims = ctx->GetInputDim("ShapeTensor"); int num_ele = 1; @@ -48,7 +51,9 @@ class FillConstantOp : public framework::OperatorWithKernel { return; } + ctx->SetOutputDim("Out", framework::make_ddim(shape)); + } protected: diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 390c4d9709a..86411f82620 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -559,12 +559,16 @@ class ReduceGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { + //std::cerr << "core here" << std::endl; + + int in_dtype = ctx.Attr("in_dtype"); auto input_data_type = (in_dtype >= 0) ? static_cast(in_dtype) : OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - + //std::cerr << "sum 1" << std::endl; + /* #ifdef PADDLE_WITH_MKLDNN auto CanMKLDNNReduceGradBeUsed = [&]() { auto dx_dims = ctx.Input("X")->dims(); @@ -580,7 +584,8 @@ class ReduceGradOp : public framework::OperatorWithKernel { framework::LibraryType::kMKLDNN); } #endif - + */ + //std::cerr << "sum 2" << std::endl; return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc index 5a8e8894e1c..46a1fe6af93 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" #include +#include namespace paddle { namespace framework { @@ -51,6 +52,7 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const { + std::cerr << "get exec" << std::endl; int in_dtype = ctx.Attr("in_dtype"); if (in_dtype >= 0) { return framework::OpKernelType( diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index fbaf76d4e7c..2b57b6f296a 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -269,9 +269,12 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( + //std::cerr << "softmax here" << std::endl; + auto res = framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Loss")), ctx.device_context()); + //std::cerr << "softmax end" << std::endl; + return res; } }; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.h b/paddle/fluid/operators/softmax_with_cross_entropy_op.h index 74316841a13..87bcf941e37 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.h +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.h @@ -28,6 +28,7 @@ template class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE_EQ( platform::is_cpu_place(context.GetPlace()), true, platform::errors::Unimplemented("This kernel only runs on CPU.")); @@ -106,20 +107,23 @@ template class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { + const Tensor* out_grad = context.Input(framework::GradVarName("Loss")); + const Tensor* labels = context.Input("Label"); Tensor* logit_grad = context.Output(framework::GradVarName("Logits")); - + const Tensor* softmax = context.Input("Softmax"); + const bool use_softmax = context.Attr("use_softmax"); - + if (logit_grad != softmax || !use_softmax) { framework::TensorCopy(*softmax, context.GetPlace(), context.device_context(), logit_grad); } - + const bool soft_label = context.Attr("soft_label"); auto ignore_index = context.Attr("ignore_index"); @@ -133,7 +137,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { logit_grad_2d.ShareDataWith(*logit_grad).Resize({n, d}); labels_2d.ShareDataWith(*labels).Resize({n, labels->numel() / n}); out_grad_2d.ShareDataWith(*out_grad).Resize({n, d / axis_dim}); - + auto out_grad_mat = framework::EigenMatrix::From(out_grad_2d); auto logit_grad_mat = framework::EigenMatrix::From(logit_grad_2d); auto& place = *context.template device_context() @@ -180,7 +184,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { } return; } - + // for use_softmax=False, continue if (soft_label) { -- GitLab