Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
7f5e5ce9
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7f5e5ce9
编写于
5月 26, 2021
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
first test version
上级
eeca9639
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
463 addition
and
9 deletion
+463
-9
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+3
-0
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+3
-0
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+427
-0
paddle/fluid/imperative/CMakeLists.txt
paddle/fluid/imperative/CMakeLists.txt
+2
-0
paddle/fluid/operators/fill_constant_op.cc
paddle/fluid/operators/fill_constant_op.cc
+6
-1
paddle/fluid/operators/reduce_ops/reduce_op.h
paddle/fluid/operators/reduce_ops/reduce_op.h
+7
-2
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
+2
-0
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
+4
-1
paddle/fluid/operators/softmax_with_cross_entropy_op.h
paddle/fluid/operators/softmax_with_cross_entropy_op.h
+9
-5
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
7f5e5ce9
...
...
@@ -383,6 +383,9 @@ cc_library(op_meta_info SRCS ../extension/src/ext_op_meta_info.cc DEPS custom_te
cc_library
(
custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper custom_tensor op_meta_info
)
cc_test
(
custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog
)
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
cc_binary
(
new_executor SRCS new_exec.cc DEPS operator op_registry executor
${
GLOB_OP_LIB
}
${
GLOB_OPERATOR_DEPS
}
profiler
)
set
(
FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator
)
cc_library
(
paddle_framework DEPS
${
FLUID_FRAMEWORK_MODULES
}
)
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
7f5e5ce9
...
...
@@ -1026,6 +1026,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
const
RuntimeContext
&
ctx_
;
};
static
void
CheckTensorNANOrInf
(
const
std
::
string
&
op_type
,
const
std
::
string
&
name
,
const
framework
::
Tensor
&
tensor
)
{
...
...
@@ -1598,7 +1599,9 @@ proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
proto
::
VarType
::
Type
dafault_data_type
=
static_cast
<
proto
::
VarType
::
Type
>
(
-
1
);
proto
::
VarType
::
Type
data_type
=
dafault_data_type
;
//std::cerr << "par in" << std::endl;
ParseInputDataType
(
ctx
,
name
,
&
data_type
);
//
PADDLE_ENFORCE_NE
(
data_type
,
dafault_data_type
,
platform
::
errors
::
InvalidArgument
(
...
...
paddle/fluid/framework/operator.h
浏览文件 @
7f5e5ce9
...
...
@@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -572,5 +573,431 @@ class OperatorWithKernel : public OperatorBase {
extern
bool
OpSupportGPU
(
const
std
::
string
&
op_type
);
/*
class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx)
: op_(op), ctx_(ctx) {}
bool HasInput(const std::string& name) const override {
// has only one input
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end()) {
return false;
}
const auto& in = it->second;
if (in.size() == 0) return false;
PADDLE_ENFORCE_EQ(
in.size(), 1UL,
platform::errors::InvalidArgument(
"Input %s should not contain more than one inputs.", name));
return in[0] != nullptr;
}
bool HasOutput(const std::string& name) const override {
// has only one output
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end()) {
return false;
}
const auto& out = it->second;
if (out.size() == 0) {
return false;
}
PADDLE_ENFORCE_EQ(
out.size(), 1UL,
platform::errors::InvalidArgument(
"Output %s should not contain more than one outputs.", name));
return out[0] != nullptr;
}
bool HasInputs(const std::string& name) const override {
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
if (it == ins.end() || it->second.empty()) {
return false;
}
for (auto& input : it->second) {
if (input == nullptr) {
return false;
}
}
return true;
}
bool HasOutputs(const std::string& name) const override {
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
if (it == outs.end() || it->second.empty()) {
return false;
}
for (auto& output : it->second) {
if (output == nullptr) {
return false;
}
}
return true;
}
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
std::vector<std::string> Inputs(const std::string& name) const override {
return op_.Inputs(name);
}
std::vector<std::string> Outputs(const std::string& name) const override {
return op_.Outputs(name);
}
std::string GetInputNameByIdx(size_t idx) const override {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of inputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(), idx, op_proto->inputs().size()));
return op_proto->inputs()[idx].name();
}
std::string GetOutputNameByIdx(size_t idx) const override {
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
PADDLE_ENFORCE_LT(
idx, op_proto->outputs().size(),
platform::errors::OutOfRange(
"The index should be less than the size of outputs of "
"operator %s, but got index is %d and size is %d",
op_.Type(), idx, op_proto->outputs().size()));
return op_proto->outputs()[idx].name();
}
void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(
in_it, ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it, ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i, in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.",
in_it->second.size(), i));
PADDLE_ENFORCE_LT(j, out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
out_it->second.size(), j));
Variable* in_var = in_it->second[i];
Variable* out_var = out_it->second[j];
PADDLE_ENFORCE_EQ(
in_var->Type(), out_var->Type(),
platform::errors::InvalidArgument(
"The type of input (%s) and output (%s) are inconsistent.", in,
out));
if (in_var->IsType<framework::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
out_sele_rows->set_rows(in_sele_rows.rows());
out_sele_rows->set_height(in_sele_rows.height());
} else if (in_var->IsType<framework::LoDTensor>()) {
auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
out_lod_tensor->Resize(in_lod_tensor.dims());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, the input type of ShareDim only can be LoDTensor "
"or SelectedRows."));
}
}
void ShareAllLoD(const std::string& in,
const std::string& out) const override {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(in_it, ctx_.inputs.end(),
platform::errors::NotFound(
"Input [%s] found error in Op [%s]", in, op_.Type()));
PADDLE_ENFORCE_NE(
out_it, ctx_.outputs.end(),
platform::errors::NotFound("Output [%s] found error in Op [%s]", out,
op_.Type()));
auto& in_var_list = in_it->second;
auto& out_var_list = out_it->second;
PADDLE_ENFORCE_EQ(
in_var_list.size(), out_var_list.size(),
platform::errors::PreconditionNotMet(
"Op [%s]: Input var size should be equal with output var size",
op_.Type()));
auto& out_var_names = op_.Outputs(out);
for (size_t i = 0; i < in_var_list.size(); ++i) {
if (out_var_names[i] == framework::kEmptyVarName) {
continue;
}
Variable* in_var = in_var_list[i];
if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = out_var_list[i];
PADDLE_ENFORCE_EQ(out_var->IsType<LoDTensor>(), true,
platform::errors::PreconditionNotMet(
"The %d-th output of Output(%s) must be LoDTensor.",
i, out_var_names[i]));
auto& in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod());
#ifdef PADDLE_WITH_MKLDNN
if (in_tensor.layout() != DataLayout::kMKLDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
}
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override {
auto in_it = ctx_.inputs.find(in);
auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_NE(
in_it, ctx_.inputs.end(),
platform::errors::NotFound("Input %s does not exist.", in));
PADDLE_ENFORCE_NE(
out_it, ctx_.outputs.end(),
platform::errors::NotFound("Output %s does not exist.", out));
PADDLE_ENFORCE_LT(i, in_it->second.size(),
platform::errors::InvalidArgument(
"The index of input dimension is out of range, "
"excepted index less than %zu, but received %zu.",
in_it->second.size(), i));
PADDLE_ENFORCE_LT(j, out_it->second.size(),
platform::errors::InvalidArgument(
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
out_it->second.size(), j));
Variable* in_var = in_it->second.at(i);
if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = out_it->second.at(j);
PADDLE_ENFORCE_EQ(
out_var->IsType<LoDTensor>(), true,
platform::errors::InvalidArgument(
"The %zu-th output of Output(%s) must be LoDTensor.", j, out));
auto& in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod());
// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to shared info between in/out Tensor?
#ifdef PADDLE_WITH_MKLDNN
// Fix me: ugly workaround below
// Correct solution:
// set_layout() should NOT be called here (i.e. ShareLoD). Instead,
// layout of output tensor should be set "manually" in Compute()
// of each OPKernel. The reason layout should NOT be shared between
// input and output "automatically" (now by InferShape()->ShareLoD())
// is that layout transform may occur after InferShape().
// Workaround:
// Skip set_layout() when input layout is kMKLDNN
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
// in Compute()
if (in_tensor.layout() != DataLayout::kMKLDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
void SetLoDLevel(const std::string& out, int32_t lod_level,
size_t j = 0) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
}
bool IsRuntime() const override { return true; }
// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) override {
const std::vector<Variable*>& vars = InputVars(name);
std::vector<InferShapeVarPtr> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) override {
const std::vector<Variable*>& vars = OutputVars(name);
std::vector<InferShapeVarPtr> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
DDim GetInputDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(), 1UL,
platform::errors::InvalidArgument(
"Input(%s) should hold one element, but now it holds %zu elements.",
name, vars.size()));
return this->GetDim(vars[0]);
}
std::vector<DDim> GetInputsDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
return GetDims(vars);
}
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override {
return GetVarTypes(InputVars(name));
}
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override {
return GetVarTypes(OutputVars(name));
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
auto& vars = OutputVars(name);
PADDLE_ENFORCE_EQ(
vars.size(), 1UL,
platform::errors::InvalidArgument("Output(%s) should hold one element, "
"but now it holds %zu elements.",
name, vars.size()));
SetDim(vars[0], dim);
}
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override {
auto& vars = OutputVars(name);
SetDims(vars, dims);
}
protected:
DDim GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::InvalidArgument("Input variable is nullptr."));
if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims();
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only LoDTensor or SelectedRows support 'GetDim', but input "
"Variable's type is %s.",
ToTypeName(var->Type())));
}
}
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const {
std::vector<DDim> ret;
ret.reserve(vars.size());
std::transform(vars.begin(), vars.end(), std::back_inserter(ret),
[this](Variable* var) { return this->GetDim(var); });
return ret;
}
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetRepeatedDims method only ban be used in compile time."));
}
void SetDim(Variable* var, const DDim& dim) {
if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Variable type error, expect LoDTensor or SelectedRows, but received "
"(%s).",
ToTypeName(var->Type())));
}
}
void SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims) {
size_t length = vars.size();
PADDLE_ENFORCE_EQ(length, dims.size(),
platform::errors::InvalidArgument(
"The number of input variables do not match the "
"number of input dimensions, the number of variables "
"is %zu, the number of dimensions is %zu.",
length, dims.size()));
for (size_t i = 0; i < length; ++i) {
if (vars[i] == nullptr) {
continue;
}
SetDim(vars[i], dims[i]);
}
}
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetRepeatedDims method only can be used in compile time."));
}
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<Variable*>& vars) const {
std::vector<proto::VarType::Type> retv;
retv.resize(vars.size());
std::transform(vars.begin(), vars.end(), retv.begin(),
std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType),
this, std::placeholders::_1));
return retv;
}
proto::VarType::Type GetVarType(Variable* var) const {
return ToVarType(var->Type());
}
private:
const std::vector<Variable*>& InputVars(const std::string& name) const {
auto it = ctx_.inputs.find(name);
PADDLE_ENFORCE_NE(
it, ctx_.inputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the input (%s).", op_.Type(), name));
return it->second;
}
const std::vector<Variable*>& OutputVars(const std::string& name) const {
auto it = ctx_.outputs.find(name);
PADDLE_ENFORCE_NE(
it, ctx_.outputs.end(),
platform::errors::NotFound(
"Operator (%s) does not have the outputs (%s).", op_.Type(), name));
return it->second;
}
const OperatorBase& op_;
const RuntimeContext& ctx_;
};
*/
}
// namespace framework
}
// namespace paddle
paddle/fluid/imperative/CMakeLists.txt
浏览文件 @
7f5e5ce9
...
...
@@ -28,4 +28,6 @@ endif(NOT WIN32)
cc_library
(
gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows selected_rows_functor var_type_traits layer math_function
)
cc_binary
(
tracer_test SRCS tracer_test.cc DEPS tracer layer op_registry python pybind
${
GLOB_OP_LIB
}
${
GLOB_OPERATOR_DEPS
}
profiler
)
add_subdirectory
(
tests
)
paddle/fluid/operators/fill_constant_op.cc
浏览文件 @
7f5e5ce9
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fill_constant_op.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
#include <iostream>
namespace
paddle
{
namespace
operators
{
...
...
@@ -23,9 +24,11 @@ class FillConstantOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"FillConstant"
);
auto
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int64_t
>>
(
"shape"
);
if
(
!
ctx
->
HasInput
(
"ShapeTensor"
)
&&
!
ctx
->
HasInputs
(
"ShapeTensorList"
))
{
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
++
i
)
{
PADDLE_ENFORCE_GE
(
...
...
@@ -48,7 +51,9 @@ class FillConstantOp : public framework::OperatorWithKernel {
return
;
}
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
shape
));
}
protected:
...
...
paddle/fluid/operators/reduce_ops/reduce_op.h
浏览文件 @
7f5e5ce9
...
...
@@ -559,12 +559,16 @@ class ReduceGradOp : public framework::OperatorWithKernel {
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
//std::cerr << "core here" << std::endl;
int
in_dtype
=
ctx
.
Attr
<
int
>
(
"in_dtype"
);
auto
input_data_type
=
(
in_dtype
>=
0
)
?
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
in_dtype
)
:
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Out"
));
//std::cerr << "sum 1" << std::endl;
/*
#ifdef PADDLE_WITH_MKLDNN
auto CanMKLDNNReduceGradBeUsed = [&]() {
auto dx_dims = ctx.Input<Tensor>("X")->dims();
...
...
@@ -580,7 +584,8 @@ class ReduceGradOp : public framework::OperatorWithKernel {
framework::LibraryType::kMKLDNN);
}
#endif
*/
//std::cerr << "sum 2" << std::endl;
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
};
...
...
paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
浏览文件 @
7f5e5ce9
...
...
@@ -15,6 +15,7 @@
#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h"
#include <string>
#include <iostream>
namespace
paddle
{
namespace
framework
{
...
...
@@ -51,6 +52,7 @@ class ReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
std
::
cerr
<<
"get exec"
<<
std
::
endl
;
int
in_dtype
=
ctx
.
Attr
<
int
>
(
"in_dtype"
);
if
(
in_dtype
>=
0
)
{
return
framework
::
OpKernelType
(
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.cc
浏览文件 @
7f5e5ce9
...
...
@@ -269,9 +269,12 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
//std::cerr << "softmax here" << std::endl;
auto
res
=
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Loss"
)),
ctx
.
device_context
());
//std::cerr << "softmax end" << std::endl;
return
res
;
}
};
...
...
paddle/fluid/operators/softmax_with_cross_entropy_op.h
浏览文件 @
7f5e5ce9
...
...
@@ -28,6 +28,7 @@ template <typename T>
class
SoftmaxWithCrossEntropyKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
context
.
GetPlace
()),
true
,
platform
::
errors
::
Unimplemented
(
"This kernel only runs on CPU."
));
...
...
@@ -106,13 +107,16 @@ template <typename T>
class
SoftmaxWithCrossEntropyGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
Tensor
*
out_grad
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Loss"
));
const
Tensor
*
labels
=
context
.
Input
<
Tensor
>
(
"Label"
);
Tensor
*
logit_grad
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Logits"
));
const
Tensor
*
softmax
=
context
.
Input
<
Tensor
>
(
"Softmax"
);
const
bool
use_softmax
=
context
.
Attr
<
bool
>
(
"use_softmax"
);
if
(
logit_grad
!=
softmax
||
!
use_softmax
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录