提交 5a4367bb 编写于 作者: Y Yang Yu

Update

上级 63e31507
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -26,12 +27,16 @@ class ActivationKernel ...@@ -26,12 +27,16 @@ class ActivationKernel
using T = typename Functor::ELEMENT_TYPE; using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto& X = detail::Ref(context.Input<framework::Tensor>("X"),
auto* Out = context.Output<framework::Tensor>("Out"); "Cannot get input tensor X, variable name = %s",
Out->mutable_data<T>(context.GetPlace()); context.op().Input("X"));
auto x = framework::EigenVector<T>::Flatten(*X); auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"),
auto out = framework::EigenVector<T>::Flatten(*Out); "Cannot get output tensor Out, variable name = %s",
context.op().Output("Out"));
Out.mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(X);
auto out = framework::EigenVector<T>::Flatten(Out);
auto* place = auto* place =
context.template device_context<DeviceContext>().eigen_device(); context.template device_context<DeviceContext>().eigen_device();
Functor functor; Functor functor;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册