提交 64464cb1 编写于 作者: S sneaxiy

Merge develop

...@@ -142,8 +142,6 @@ class Graph { ...@@ -142,8 +142,6 @@ class Graph {
nodes_.erase(node); nodes_.erase(node);
} }
const ProgramDesc &program() const { return program_; }
private: private:
// This method takes ownership of `node`. // This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) { ir::Node *AddNode(ir::Node *node) {
...@@ -154,7 +152,7 @@ class Graph { ...@@ -154,7 +152,7 @@ class Graph {
} }
// NOTE: program_ shouldn't be exposed to user. // NOTE: program_ shouldn't be exposed to user.
const ProgramDesc &program_; const ProgramDesc program_;
std::map<std::string, boost::any> attrs_; std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_; std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
......
...@@ -41,8 +41,7 @@ class Node { ...@@ -41,8 +41,7 @@ class Node {
explicit Node(OpDesc* op_desc) explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()), : name_(op_desc->Type()),
var_desc_(nullptr), var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc)), // TODO(panyx0718) the pointer in the op_desc_(new OpDesc(*op_desc, op_desc->Block())),
// original OpDesc might go out.
type_(Type::kOperation) {} type_(Type::kOperation) {}
Type NodeType() const { return type_; } Type NodeType() const { return type_; }
......
...@@ -129,6 +129,10 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, ...@@ -129,6 +129,10 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
"Optimized for variable") "Optimized for variable")
.SetDefault({}); .SetDefault({});
AddAttr<std::vector<std::string>>(OpCreationCallstackAttrName(),
"Callstack for Op Creatation.")
.SetDefault({});
Validate(); Validate();
} }
......
...@@ -39,6 +39,7 @@ class OpProtoAndCheckerMaker { ...@@ -39,6 +39,7 @@ class OpProtoAndCheckerMaker {
public: public:
static const char *OpRoleAttrName() { return "op_role"; } static const char *OpRoleAttrName() { return "op_role"; }
static const char *OpRoleVarAttrName() { return "op_role_var"; } static const char *OpRoleVarAttrName() { return "op_role_var"; }
static const char *OpCreationCallstackAttrName() { return "op_callstack"; }
void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker); void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
......
...@@ -11,15 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,15 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <gflags/gflags.h> #include "paddle/fluid/framework/operator.h"
#include <glog/logging.h>
#include <algorithm> #include <algorithm>
#include <sstream>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/framework/data_transform.h" #include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -127,19 +129,48 @@ static LoD GetLoD(const Scope& scope, const std::string& name) { ...@@ -127,19 +129,48 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
} }
void OperatorBase::Run(const Scope& scope, const platform::Place& place) { void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
VLOG(4) << place << " " << DebugStringEx(&scope); try {
if (platform::is_gpu_place(place)) { if (VLOG_IS_ON(4)) {
VLOG(4) << place << " " << DebugStringEx(&scope);
}
if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA #ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot run operator on place %s", place); PADDLE_THROW("Cannot run operator on place %s", place);
#else #else
auto dev_id = boost::get<platform::CUDAPlace>(place).device; auto dev_id = boost::get<platform::CUDAPlace>(place).device;
platform::SetDeviceId(dev_id); platform::SetDeviceId(dev_id);
#endif #endif
}
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
RunImpl(scope, place);
if (VLOG_IS_ON(3)) {
VLOG(3) << place << " " << DebugStringEx(&scope);
}
} catch (platform::EnforceNotMet exception) {
if (Attrs().count("sub_block") != 0) {
throw exception;
}
auto& callstack = Attr<std::vector<std::string>>(
OpProtoAndCheckerMaker::OpCreationCallstackAttrName());
if (callstack.empty()) {
throw exception;
}
std::ostringstream sout;
sout << "Invoke operator " << Type() << " error.\n";
sout << "Python Callstacks: \n";
for (auto& line : callstack) {
sout << line;
}
sout << "C++ Callstacks: \n";
sout << exception.err_str_;
exception.err_str_ = sout.str();
throw exception;
} catch (...) {
std::rethrow_exception(std::current_exception());
} }
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
RunImpl(scope, place);
VLOG(3) << place << " " << DebugStringEx(&scope);
} }
bool OperatorBase::HasInputs(const std::string& name) const { bool OperatorBase::HasInputs(const std::string& name) const {
...@@ -167,7 +198,7 @@ const std::vector<std::string>& OperatorBase::Inputs( ...@@ -167,7 +198,7 @@ const std::vector<std::string>& OperatorBase::Inputs(
} }
bool OperatorBase::HasOutputs(const std::string& name) const { bool OperatorBase::HasOutputs(const std::string& name) const {
if (outputs_.find(name) != outputs_.end()) { if (outputs_.end() != outputs_.find(name)) {
return true; return true;
} else { } else {
return false; return false;
......
...@@ -62,9 +62,21 @@ class ConcatGradKernel : public framework::OpKernel<T> { ...@@ -62,9 +62,21 @@ class ConcatGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out_grad = auto* out_grad =
ctx.Input<framework::Tensor>(framework::GradVarName("Out")); ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<framework::Tensor>("X"); auto ins = ctx.MultiInput<framework::LoDTensor>("X");
auto out_var_names = ctx.Outputs(framework::GradVarName("X")); auto out_var_names = ctx.Outputs(framework::GradVarName("X"));
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X")); auto outs =
ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
{
auto dx = outs;
auto x = ins;
for (size_t i = 0; i < dx.size(); ++i) {
if (dx[i] != nullptr) {
dx[i]->set_lod(x[i]->lod());
}
}
}
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis")); int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
// get output tensor that the name is not kEmptyVarName // get output tensor that the name is not kEmptyVarName
......
...@@ -137,9 +137,10 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -137,9 +137,10 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
}; };
template <typename T> template <typename T>
class EltwiseAddMKLDNNGradKernel : public framework::OpKernel<T> { class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
...@@ -136,9 +137,11 @@ elementwise_add_grad(const framework::ExecutionContext& ctx, ...@@ -136,9 +137,11 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseAddGradKernel : public framework::OpKernel<T> { class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise_op_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -53,9 +53,10 @@ struct DivGradDY { ...@@ -53,9 +53,10 @@ struct DivGradDY {
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseDivGradKernel : public framework::OpKernel<T> { class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise_op_function.h"
namespace paddle { namespace paddle {
...@@ -55,9 +56,10 @@ struct MaxGradDy { ...@@ -55,9 +56,10 @@ struct MaxGradDy {
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseMaxGradKernel : public framework::OpKernel<T> { class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
......
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise_op_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -55,9 +55,10 @@ struct MinGradDy { ...@@ -55,9 +55,10 @@ struct MinGradDy {
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseMinGradKernel : public framework::OpKernel<T> { class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
...@@ -84,9 +85,10 @@ struct MulGradDY { ...@@ -84,9 +85,10 @@ struct MulGradDY {
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseMulGradKernel : public framework::OpKernel<T> { class ElementwiseMulGradKernel : public ElemwiseGradKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
......
...@@ -205,6 +205,20 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad { ...@@ -205,6 +205,20 @@ class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
} }
}; };
template <typename T>
class ElemwiseGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dx =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (dx != nullptr) {
auto& dout =
*context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
dx->set_lod(dout.lod());
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/elementwise_op.h"
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise_op_function.h"
namespace paddle { namespace paddle {
...@@ -50,9 +51,10 @@ struct SubGradDY { ...@@ -50,9 +51,10 @@ struct SubGradDY {
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseSubGradKernel : public framework::OpKernel<T> { class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
...@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> { ...@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
const std::vector<const framework::Tensor*>& ref_inputs, const std::vector<const framework::LoDTensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) { const int axis, std::vector<framework::Tensor*>* outputs) {
// TODO(zcd): Add input data validity checking // TODO(zcd): Add input data validity checking
size_t num = outputs->size(); size_t num = outputs->size();
......
...@@ -189,7 +189,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> { ...@@ -189,7 +189,7 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const framework::Tensor& input,
const std::vector<const framework::Tensor*>& ref_inputs, const std::vector<const framework::LoDTensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) { const int axis, std::vector<framework::Tensor*>* outputs) {
// TODO(zcd): Add input data validity checking // TODO(zcd): Add input data validity checking
int o_num = outputs->size(); int o_num = outputs->size();
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -57,7 +57,7 @@ template <typename DeviceContext, typename T> ...@@ -57,7 +57,7 @@ template <typename DeviceContext, typename T>
class ConcatGradFunctor { class ConcatGradFunctor {
public: public:
void operator()(const DeviceContext& context, const framework::Tensor& input, void operator()(const DeviceContext& context, const framework::Tensor& input,
const std::vector<const framework::Tensor*>& ref_inputs, const std::vector<const framework::LoDTensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs); const int axis, std::vector<framework::Tensor*>* outputs);
}; };
......
...@@ -62,23 +62,31 @@ class MulGradKernel : public framework::OpKernel<T> { ...@@ -62,23 +62,31 @@ class MulGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims"); int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims"); int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
const Tensor* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<framework::LoDTensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y"); auto* y = ctx.Input<framework::LoDTensor>("Y");
const Tensor x_matrix = x->dims().size() > 2 auto x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, x_num_col_dims) ? framework::ReshapeToMatrix(*x, x_num_col_dims)
: *x; : static_cast<const Tensor&>(*x);
const Tensor y_matrix = y->dims().size() > 2 auto y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, y_num_col_dims) ? framework::ReshapeToMatrix(*y, y_num_col_dims)
: *y; : static_cast<const Tensor&>(*y);
const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
Tensor dout_mat; Tensor dout_mat;
dout_mat.ShareDataWith(*dout); dout_mat.ShareDataWith(*dout);
dout_mat.Resize({framework::flatten_to_2d(x->dims(), x_num_col_dims)[0], dout_mat.Resize({framework::flatten_to_2d(x->dims(), x_num_col_dims)[0],
framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]}); framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]});
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y")); auto* dy = ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
if (dx != nullptr) {
dx->set_lod(x->lod());
}
if (dy != nullptr) {
dy->set_lod(y->lod());
}
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx); auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
if (dx) { if (dx) {
......
...@@ -68,7 +68,9 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> { ...@@ -68,7 +68,9 @@ class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out")); auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x = ctx.Input<LoDTensor>("X"); auto* x = ctx.Input<LoDTensor>("X");
auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X")); auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (x_grad) {
x_grad->set_lod(x->lod());
}
auto lod = x->lod(); auto lod = x->lod();
const size_t level = lod.size() - 1; const size_t level = lod.size() - 1;
......
...@@ -66,6 +66,9 @@ class SequenceSoftmaxGradKernel : public framework::OpKernel<T> { ...@@ -66,6 +66,9 @@ class SequenceSoftmaxGradKernel : public framework::OpKernel<T> {
auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out")); auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x = ctx.Input<LoDTensor>("X"); auto* x = ctx.Input<LoDTensor>("X");
auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X")); auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (x_grad) {
x_grad->set_lod(x->lod());
}
auto lod = x->lod(); auto lod = x->lod();
const size_t level = lod.size() - 1; const size_t level = lod.size() - 1;
......
...@@ -30,6 +30,8 @@ class TopkOp : public framework::OperatorWithKernel { ...@@ -30,6 +30,8 @@ class TopkOp : public framework::OperatorWithKernel {
"Output(Indices) of TopkOp should not be null."); "Output(Indices) of TopkOp should not be null.");
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(input_dims.size(), 2,
"Rank of TopK op's input must be 2.");
const int k = static_cast<int>(ctx->Attrs().Get<int>("k")); const int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_GE(k, 1, "k must >= 1"); PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
......
...@@ -40,6 +40,9 @@ void BindConstValue(pybind11::module* m) { ...@@ -40,6 +40,9 @@ void BindConstValue(pybind11::module* m) {
op_proto_and_checker_maker.def( op_proto_and_checker_maker.def(
"kOpRoleVarAttrName", "kOpRoleVarAttrName",
framework::OpProtoAndCheckerMaker::OpRoleVarAttrName); framework::OpProtoAndCheckerMaker::OpRoleVarAttrName);
op_proto_and_checker_maker.def(
"kOpCreationCallstackAttrName",
framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName);
} }
} // namespace pybind } // namespace pybind
......
...@@ -18,6 +18,7 @@ import collections ...@@ -18,6 +18,7 @@ import collections
import contextlib import contextlib
import re import re
import six import six
import traceback
import numpy as np import numpy as np
...@@ -499,6 +500,10 @@ class Operator(object): ...@@ -499,6 +500,10 @@ class Operator(object):
if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0: if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0:
del op_attrs[role_var_name] del op_attrs[role_var_name]
callstack_var_name = op_maker.kOpCreationCallstackAttrName()
op_attrs[callstack_var_name] = list(
reversed(traceback.format_stack()))[1:]
if len(self.desc.type()) != 0: if len(self.desc.type()) != 0:
return return
if type is None: if type is None:
......
...@@ -67,7 +67,10 @@ class TestOperator(unittest.TestCase): ...@@ -67,7 +67,10 @@ class TestOperator(unittest.TestCase):
self.assertEqual(mul_op.output("Out"), ["mul.out"]) self.assertEqual(mul_op.output("Out"), ["mul.out"])
self.assertEqual( self.assertEqual(
set(mul_op.attr_names), set(mul_op.attr_names),
set(["x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var"])) set([
"x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var",
"op_callstack"
]))
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("x_num_col_dims"), 1) self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册