Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7f5e5ce9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7f5e5ce9
编写于
5月 26, 2021
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
first test version
上级
eeca9639
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
463 addition
and
9 deletion
+463
-9
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+3
-0
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+3
-0
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+427
-0
paddle/fluid/imperative/CMakeLists.txt
paddle/fluid/imperative/CMakeLists.txt
+2
-0
paddle/fluid/operators/fill_constant_op.cc
paddle/fluid/operators/fill_constant_op.cc
+6
-1
paddle/fluid/operators/reduce_ops/reduce_op.h
paddle/fluid/operators/reduce_ops/reduce_op.h
+7
-2
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
+2
-0
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
+4
-1
paddle/fluid/operators/softmax_with_cross_entropy_op.h
paddle/fluid/operators/softmax_with_cross_entropy_op.h
+9
-5
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
7f5e5ce9
...
...
@@ -383,6 +383,9 @@ cc_library(op_meta_info SRCS ../extension/src/ext_op_meta_info.cc DEPS custom_te
cc_library
(
custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper custom_tensor op_meta_info
)
cc_test
(
custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog
)
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
cc_binary
(
new_executor SRCS new_exec.cc DEPS operator op_registry executor
${
GLOB_OP_LIB
}
${
GLOB_OPERATOR_DEPS
}
profiler
)
set
(
FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator
)
cc_library
(
paddle_framework DEPS
${
FLUID_FRAMEWORK_MODULES
}
)
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
7f5e5ce9
...
...
@@ -1026,6 +1026,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
const
RuntimeContext
&
ctx_
;
};
static
void
CheckTensorNANOrInf
(
const
std
::
string
&
op_type
,
const
std
::
string
&
name
,
const
framework
::
Tensor
&
tensor
)
{
...
...
@@ -1598,7 +1599,9 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
proto
::
VarType
::
Type
dafault_data_type
=
static_cast
<
proto
::
VarType
::
Type
>
(
-
1
);
proto
::
VarType
::
Type
data_type
=
dafault_data_type
;
//std::cerr << "par in" << std::endl;
ParseInputDataType
(
ctx
,
name
,
&
data_type
);
//
PADDLE_ENFORCE_NE
(
data_type
,
dafault_data_type
,
platform
::
errors
::
InvalidArgument
(
...
...
paddle/fluid/framework/operator.h
浏览文件 @
7f5e5ce9
...
...
@@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -572,5 +573,431 @@ class OperatorWithKernel : public OperatorBase {
extern
bool
OpSupportGPU
(
const
std
::
string
&
op_type
);
/*
class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx)
: op_(op), ctx_(ctx) {}
bool HasInput(const std::string& name) const override {
// has only one input
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end()) {
return false;
}
const auto& in = it->second;
if (in.size() == 0) return false;
PADDLE_ENFORCE_EQ(
in.size(), 1UL,
platform::errors::InvalidArgument(
"Input %s should not contain more than one inputs.", name));
return in[0] != nullptr;
}
bool HasOutput(const std::string& name) const override {
// has only one output
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end()) {
return false;
}
const auto& out = it->second;
if (out.size() == 0) {
return false;
}
PADDLE_ENFORCE_EQ(
out.size(), 1UL,
platform::errors::InvalidArgument(
"Output %s should not contain more than one outputs.", name));
return out[0] != nullptr;
}
bool HasInputs(const std::string& name) const override {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end() || it->second.empty()) {
return false;
}
for (auto& input : it->second) {
if (input == nullptr) {
return false;
}
}
return true;
}
bool HasOutputs(const std::string& name) const override {
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end() || it->second.empty()) {
return false;
}
for (auto& output : it->second) {
if (output == nullptr) {
return false;
}
}
return true;
}
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
std::vector<std::string> Inputs(const std::string& name) const override {
return op_.Inputs(name);
}
std::vector<std::string> Outputs(const std::string& name) const override {
return op_.Outputs(name);
}
std::string GetInputNameByIdx(size_t idx) const override {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(), idx, op_proto->inputs().size()));
return op_proto->inputs()[idx].name();
}
std::string GetOutputNameByIdx(size_t idx) const override {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(
idx, op_proto->outputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(), idx, op_proto->outputs().size()));
return op_proto->outputs()[idx].name();
}
void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(
in_it, ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it, ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i, in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.",
in_it->second.size(), i));
PADDLE_ENFORCE_LT(j, out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
out_it->second.size(), j));
Variable* in_var = in_it->second[i];
Variable* out_var = out_it->second[j];
PADDLE_ENFORCE_EQ(
in_var->Type(), out_var->Type(),
platform::errors::InvalidArgument(
"The type of input (%s) and output (%s) are inconsistent.", in,
out));
if (in_var->IsType<framework::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
out_sele_rows->set_rows(in_sele_rows.rows());
out_sele_rows->set_height(in_sele_rows.height());
} else if (in_var->IsType<framework::LoDTensor>()) {
auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
out_lod_tensor->Resize(in_lod_tensor.dims());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, the input type of ShareDim only can be LoDTensor "
"or SelectedRows."));
}
}
void ShareAllLoD(const std::string& in,
const std::string& out) const override {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it, ctx_.inputs.end(),
platform::errors::NotFound(
"Input [%s] found error in Op [%s]", in, op_.Type()));
PADDLE_ENFORCE_NE(
out_it, ctx_.outputs.end(),
platform::errors::NotFound("Output [%s] found error in Op [%s]", out,
op_.Type()));
auto& in_var_list = in_it->second;
auto& out_var_list = out_it->second;
PADDLE_ENFORCE_EQ(
in_var_list.size(), out_var_list.size(),
platform::errors::PreconditionNotMet(
"Op [%s]: Input var size should be equal with output var size",
op_.Type()));
auto& out_var_names = op_.Outputs(out);
for (size_t i = 0; i < in_var_list.size(); ++i) {
if (out_var_names[i] == framework::kEmptyVarName) {
continue;
}
Variable* in_var = in_var_list[i];
if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = out_var_list[i];
PADDLE_ENFORCE_EQ(out_var->IsType<LoDTensor>(), true,
platform::errors::PreconditionNotMet(
"The %d-th output of Output(%s) must be LoDTensor.",
i, out_var_names[i]));
auto& in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod());
#ifdef PADDLE_WITH_MKLDNN
if (in_tensor.layout() != DataLayout::kMKLDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
}
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(
in_it, ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it, ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i, in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.",
in_it->second.size(), i));
PADDLE_ENFORCE_LT(j, out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
out_it->second.size(), j));
Variable* in_var = in_it->second.at(i);
if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = out_it->second.at(j);
PADDLE_ENFORCE_EQ(
out_var->IsType<LoDTensor>(), true,
platform::errors::InvalidArgument(
"The %zu-th output of Output(%s) must be LoDTensor.", j, out));
auto& in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod());
// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to shared info between in/out Tensor?
#ifdef PADDLE_WITH_MKLDNN
// Fix me: ugly workaround below
// Correct solution:
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
// layout of output tensor should be set "manually" in Compute()
// of each OPKernel. The reason layout should NOT be shared between
// input and output "automatically" (now by InferShape()->ShareLoD())
// is that layout transform may occur after InferShape().
// Workaround:
// Skip set_layout() when input layout is kMKLDNN
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
// in Compute()
if (in_tensor.layout() != DataLayout::kMKLDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
void SetLoDLevel(const std::string& out, int32_t lod_level,
size_t j = 0) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
bool IsRuntime() const override { return true; }
// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) override {
const std::vector<Variable*>& vars = InputVars(name);
std::vector<InferShapeVarPtr> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) override {
const std::vector<Variable*>& vars = OutputVars(name);
std::vector<InferShapeVarPtr> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
DDim GetInputDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(), 1UL,
platform::errors::InvalidArgument(
"Input(%s) should hold one element, but now it holds %zu elements.",
name, vars.size()));
return this->GetDim(vars[0]);
}
std::vector<DDim> GetInputsDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
return GetDims(vars);
}
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override {
return GetVarTypes(InputVars(name));
}
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override {
return GetVarTypes(OutputVars(name));
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
auto& vars = OutputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(), 1UL,
platform::errors::InvalidArgument("Output(%s) should hold one element, "
"but now it holds %zu elements.",
name, vars.size()));
SetDim(vars[0], dim);
}
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override {
auto& vars = OutputVars(name);
SetDims(vars, dims);
}
protected:
DDim GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input variable is nullptr."));
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only LoDTensor or SelectedRows support 'GetDim', but input "
"Variable's type is %s.",
ToTypeName(var->Type())));
}
}
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const {
std::vector<DDim> ret;
ret.reserve(vars.size());
std::transform(vars.begin(), vars.end(), std::back_inserter(ret),
[this](Variable* var) { return this->GetDim(var); });
return ret;
}
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetRepeatedDims method only ban be used in compile time."));
}
void SetDim(Variable* var, const DDim& dim) {
if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Variable type error, expect LoDTensor or SelectedRows, but received "
"(%s).",
ToTypeName(var->Type())));
}
}
void SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims) {
size_t length = vars.size();
PADDLE_ENFORCE_EQ(length, dims.size(),
platform::errors::InvalidArgument(
"The number of input variables do not match the "
"number of input dimensions, the number of variables "
"is %zu, the number of dimensions is %zu.",
length, dims.size()));
for (size_t i = 0; i < length; ++i) {
if (vars[i] == nullptr) {
continue;
}
SetDim(vars[i], dims[i]);
}
}
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetRepeatedDims method only can be used in compile time."));
}
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<Variable*>& vars) const {
std::vector<proto::VarType::Type> retv;
retv.resize(vars.size());
std::transform(vars.begin(), vars.end(), retv.begin(),
std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType),
this, std::placeholders::_1));
return retv;
}
proto::VarType::Type GetVarType(Variable* var) const {
return ToVarType(var->Type());
}
private:
const std::vector<Variable*>& InputVars(const std::string& name) const {
auto it = ctx_.inputs.find(name);
PADDLE_ENFORCE_NE(
it, ctx_.inputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the input (%s).", op_.Type(), name));
return it->second;
}
const std::vector<Variable*>& OutputVars(const std::string& name) const {
auto it = ctx_.outputs.find(name);
PADDLE_ENFORCE_NE(
it, ctx_.outputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the outputs (%s).", op_.Type(), name));
return it->second;
}
const OperatorBase& op_;
const RuntimeContext& ctx_;
};
*/
}
// namespace framework
}
// namespace paddle
paddle/fluid/imperative/CMakeLists.txt
浏览文件 @
7f5e5ce9
...
...
@@ -28,4 +28,6 @@ endif(NOT WIN32)
cc_library
(
gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function
)
cc_binary
(
tracer_test SRCS tracer_test.cc DEPS tracer layer op_registry python pybind
${
GLOB_OP_LIB
}
${
GLOB_OPERATOR_DEPS
}
profiler
)
add_subdirectory
(
tests
)
paddle/fluid/operators/fill_constant_op.cc
浏览文件 @
7f5e5ce9
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fill_constant_op.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
#include <iostream>
namespace
paddle
{
namespace
operators
{
...
...
@@ -23,9 +24,11 @@ class FillConstantOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"FillConstant"
);
auto
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int64_t
>>
(
"shape"
);
if
(
!
ctx
->
HasInput
(
"ShapeTensor"
)
&&
!
ctx
->
HasInputs
(
"ShapeTensorList"
))
{
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
PADDLE_ENFORCE_GE
(
...
...
@@ -36,7 +39,7 @@ class FillConstantOp : public framework::OperatorWithKernel {
i
,
shape
[
i
],
framework
::
make_ddim
(
shape
)));
}
}
if
(
shape
.
empty
()
&&
ctx
->
HasInput
(
"ShapeTensor"
))
{
auto
shape_dims
=
ctx
->
GetInputDim
(
"ShapeTensor"
);
int
num_ele
=
1
;
...
...
@@ -48,7 +51,9 @@ class FillConstantOp : public framework::OperatorWithKernel {
return
;
}
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
shape
));
}
protected:
...
...
paddle/fluid/operators/reduce_ops/reduce_op.h
浏览文件 @
7f5e5ce9
...
...
@@ -559,12 +559,16 @@ class ReduceGradOp : public framework::OperatorWithKernel {
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
//std::cerr << "core here" << std::endl;
int
in_dtype
=
ctx
.
Attr
<
int
>
(
"in_dtype"
);
auto
input_data_type
=
(
in_dtype
>=
0
)
?
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
in_dtype
)
:
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Out"
));
//std::cerr << "sum 1" << std::endl;
/*
#ifdef PADDLE_WITH_MKLDNN
auto CanMKLDNNReduceGradBeUsed = [&]() {
auto dx_dims = ctx.Input<Tensor>("X")->dims();
...
...
@@ -580,7 +584,8 @@ class ReduceGradOp : public framework::OperatorWithKernel {
framework::LibraryType::kMKLDNN);
}
#endif
*/
//std::cerr << "sum 2" << std::endl;
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
};
...
...
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
浏览文件 @
7f5e5ce9
...
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include <string>
#include <iostream>
namespace
paddle
{
namespace
framework
{
...
...
@@ -51,6 +52,7 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
std
::
cerr
<<
"get exec"
<<
std
::
endl
;
int
in_dtype
=
ctx
.
Attr
<
int
>
(
"in_dtype"
);
if
(
in_dtype
>=
0
)
{
return
framework
::
OpKernelType
(
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
7f5e5ce9
...
...
@@ -269,9 +269,12 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
//std::cerr << "softmax here" << std::endl;
auto
res
=
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Loss"
)),
ctx
.
device_context
());
//std::cerr << "softmax end" << std::endl;
return
res
;
}
};
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.h
浏览文件 @
7f5e5ce9
...
...
@@ -28,6 +28,7 @@ template <typename T>
class
SoftmaxWithCrossEntropyKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
context
.
GetPlace
()),
true
,
platform
::
errors
::
Unimplemented
(
"This kernel only runs on CPU."
));
...
...
@@ -106,20 +107,23 @@ template <typename T>
class
SoftmaxWithCrossEntropyGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
Tensor
*
out_grad
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
));
const
Tensor
*
labels
=
context
.
Input
<
Tensor
>
(
"Label"
);
Tensor
*
logit_grad
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Logits"
));
const
Tensor
*
softmax
=
context
.
Input
<
Tensor
>
(
"Softmax"
);
const
bool
use_softmax
=
context
.
Attr
<
bool
>
(
"use_softmax"
);
if
(
logit_grad
!=
softmax
||
!
use_softmax
)
{
framework
::
TensorCopy
(
*
softmax
,
context
.
GetPlace
(),
context
.
device_context
(),
logit_grad
);
}
const
bool
soft_label
=
context
.
Attr
<
bool
>
(
"soft_label"
);
auto
ignore_index
=
context
.
Attr
<
int
>
(
"ignore_index"
);
...
...
@@ -133,7 +137,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
logit_grad_2d
.
ShareDataWith
(
*
logit_grad
).
Resize
({
n
,
d
});
labels_2d
.
ShareDataWith
(
*
labels
).
Resize
({
n
,
labels
->
numel
()
/
n
});
out_grad_2d
.
ShareDataWith
(
*
out_grad
).
Resize
({
n
,
d
/
axis_dim
});
auto
out_grad_mat
=
framework
::
EigenMatrix
<
T
>::
From
(
out_grad_2d
);
auto
logit_grad_mat
=
framework
::
EigenMatrix
<
T
>::
From
(
logit_grad_2d
);
auto
&
place
=
*
context
.
template
device_context
<
platform
::
CPUDeviceContext
>()
...
...
@@ -180,7 +184,7 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
}
return
;
}
// for use_softmax=False, continue
if
(
soft_label
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录