具有资源句柄的 TensorFlow 自定义 C++ op

Python代码:


import os
import sys
from subprocess import check_call
import tensorflow as tf


# C++ source of the custom ops and the shared library built from it; the
# library name is the source name with its extension swapped for ".so".
CC_NAME = "tf-resource-op.cc"
SO_NAME = os.path.splitext(CC_NAME)[0] + ".so"


def compile_so():
  """Compile CC_NAME into the shared library SO_NAME with g++.

  Header and library locations come from ``tf.sysconfig``, so the op is built
  against the installed TensorFlow.  Raises ``subprocess.CalledProcessError``
  if the compiler exits with an error.
  """
  # The attribute is spelled CXX11_ABI_FLAG. Newer TF exposes it on
  # tf.sysconfig and older 1.x releases on the top-level module. The previous
  # mixed-case spelling never matched, so the ABI define was always 0.
  use_cxx11_abi = getattr(tf.sysconfig, "CXX11_ABI_FLAG", getattr(tf, "CXX11_ABI_FLAG", 0))
  common_opts = ["-shared", "-O2"]
  common_opts += ["-std=c++11"]
  if sys.platform == "darwin":
    # macOS linker: leave TensorFlow symbols to be resolved at load time.
    common_opts += ["-undefined", "dynamic_lookup"]
  tf_include = tf.sysconfig.get_include()  # e.g. "...python2.7/site-packages/tensorflow/include"
  tf_include_nsync = tf_include + "/external/nsync/public"  # https://github.com/tensorflow/tensorflow/issues/2412
  include_paths = [tf_include, tf_include_nsync]
  for include_path in include_paths:
    common_opts += ["-I", include_path]
  common_opts += ["-fPIC", "-v"]
  common_opts += ["-D_GLIBCXX_USE_CXX11_ABI=%i" % (1 if use_cxx11_abi else 0)]
  # Build like TensorFlow's release binaries. Without NDEBUG, the RefCounted
  # destructor inlined from the TF headers into this .so keeps a debug check.
  # That check aborts with "Check failed: ref_.load() == 0 (1 vs. 0)" when the
  # session deletes the resource (workaround from TF issue 17316).
  common_opts += ["-DNDEBUG"]
  common_opts += ["-g"]
  opts = common_opts + [CC_NAME, "-o", SO_NAME]
  ld_flags = ["-L%s" % tf.sysconfig.get_lib(), "-ltensorflow_framework"]
  opts += ld_flags
  cmd_bin = "g++"
  cmd_args = [cmd_bin] + opts
  print("$ %s" % " ".join(cmd_args))
  check_call(cmd_args)


def main():
  """Build the op library if needed, load it, and run one FST transition."""
  # Rebuild when the library is missing or older than its C++ source, so
  # edits to CC_NAME are not silently ignored in favour of a stale .so.
  so_is_stale = os.path.exists(CC_NAME) and os.path.exists(SO_NAME) and (
    os.path.getmtime(CC_NAME) > os.path.getmtime(SO_NAME))
  if not os.path.exists(SO_NAME) or so_is_stale:
    compile_so()
  mod = tf.load_op_library(SO_NAME)
  handle = mod.open_fst_load(filename="foo.bar")
  new_states, scores = mod.open_fst_transition(handle=handle, inputs=[0], states=[0])

  with tf.Session() as session:
    # Do not fetch `handle` itself. It is a `resource` tensor, and fetching it
    # with session.run fails with:
    #   InternalError: ndarray was 1 bytes but TF_Tensor was 98 bytes
    # Only pass it as an input to other ops, as done below.
    # print("fst:",session.run(handle))

    out_new_states, out_scores = session.run((new_states, scores))
    print("output new states:", out_new_states)
    print("output scores:", out_scores)

    # If the .so was built without -DNDEBUG, closing the session can abort with:
    # F .../site-packages/tensorflow/include/tensorflow/core/lib/core/refcount.h:79] Check failed: ref_.load() == 0 (1 vs. 0)  # nopep8
    # (see TF issue 17316; building with -DNDEBUG avoids it).


if __name__ == '__main__':
  # better_exchook (third-party) only improves traceback output; run without
  # it when it is not installed.
  try:
    import better_exchook
  except ImportError:
    pass
  else:
    better_exchook.install()
  main()

C++ 代码:


#include <exception>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

using namespace tensorflow;


REGISTER_OP("OpenFstLoad")
.Attr("filename: string")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Output("handle: resource")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape)
.Doc("OpenFstLoad: loads FST,creates TF resource,persistent across runs in the session");


// Inputs: the resource handle produced by OpenFstLoad, plus two rank-1 int32
// tensors `states` and `inputs` of equal length. Both outputs have that same
// rank-1 shape.
REGISTER_OP("OpenFstTransition")
.Input("handle: resource")
.Input("states: int32")
.Input("inputs: int32")
.Output("new_states: int32")
.Output("scores: float32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
  // Enforce at graph-construction time what the kernel checks at run time.
  // Merging also refines an unknown dimension on one side from the other.
  ::tensorflow::shape_inference::ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
  ::tensorflow::shape_inference::ShapeHandle states;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &states));
  ::tensorflow::shape_inference::ShapeHandle inputs;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &inputs));
  ::tensorflow::shape_inference::ShapeHandle merged;
  TF_RETURN_IF_ERROR(c->Merge(states, inputs, &merged));
  c->set_output(0, merged);
  c->set_output(1, merged);
  return Status::OK();
})
.Doc("OpenFstTransition: performs a transition");


// Resource object created by OpenFstLoadOp and later deleted through the
// ResourceMgr. For now it only remembers the file it was created for.
struct OpenFstInstance : public ResourceBase {
  explicit OpenFstInstance(const string& filename) : filename_{filename} {}

  // Human-readable description, e.g. "OpenFstInstance[foo.bar]".
  string DebugString() override {
    string description = "OpenFstInstance[";
    description.append(filename_);
    description.push_back(']');
    return description;
  }

  // File this instance was created from; fixed at construction.
  const string filename_;
};


// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/resource_op_kernel.h
// TFUtil.TFArrayContainer
// Kernel for OpenFstLoad: creates (or verifies) an OpenFstInstance through the
// ResourceOpKernel base class and outputs its handle.
class OpenFstLoadOp : public ResourceOpKernel<OpenFstInstance> {
 public:
  // Reads the required "filename" attr. The "container"/"shared_name" attrs
  // declared on the op are not read here; presumably the ResourceOpKernel base
  // handles them (TODO confirm against resource_op_kernel.h).
  explicit OpenFstLoadOp(OpKernelConstruction* context)
      : ResourceOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("filename", &filename_));
  }

 private:
  // NOTE(review): these do not override anything visible here; they look like
  // leftovers from a queue op.
  virtual bool IsCancellable() const { return false; }
  virtual void Cancel() {}

  // Allocates a new OpenFstInstance for filename_. Exceptions (e.g. from
  // future FST loading code) are turned into an Internal error status.
  Status CreateResource(OpenFstInstance** ret) override EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    try {
      *ret = new OpenFstInstance(filename_);
    } catch (const std::exception& exc) {
      return errors::Internal("Could not load OpenFst ", filename_, ", exception: ", exc.what());
    }
    if (*ret == nullptr)
      return errors::ResourceExhausted("Failed to allocate");
    return Status::OK();
  }

  // Rejects an existing resource that was created for a different file.
  Status VerifyResource(OpenFstInstance* fst) override {
    if (fst->filename_ != filename_)
      return errors::InvalidArgument("Filename mismatch: expected ", filename_, " but got ", fst->filename_, ".");
    return Status::OK();
  }

  string filename_;
};

REGISTER_KERNEL_BUILDER(Name("OpenFstLoad").Device(DEVICE_CPU), OpenFstLoadOp);


class OpenFstTransitionOp : public OpKernel {
 public:
  using OpKernel::OpKernel;

  void Compute(OpKernelContext* context) override {
    OpenFstInstance* fst;
    OP_REQUIRES_OK(context,GetResourceFromContext(context,"handle",&fst));
    core::ScopedUnref unref(fst);

    const Tensor& states_tensor = context->input(1);
    auto states_flat = states_tensor.flat<int32>();

    const Tensor& inputs_tensor = context->input(2);
    auto inputs_flat = inputs_tensor.flat<int32>();

    OP_REQUIRES(
      context,TensorShapeUtils::IsVector(states_tensor.shape()) &&
      TensorShapeUtils::IsVector(inputs_tensor.shape()) &&
      states_flat.size() == inputs_flat.size(),errors::InvalidArgument(
        "Shape mismatch. states ",states_tensor.shape().DebugString()," vs inputs ",inputs_tensor.shape().DebugString()));

    Tensor* output_new_states_tensor = NULL;
    OP_REQUIRES_OK(context,context->allocate_output(0,states_tensor.shape(),&output_new_states_tensor));
    auto output_new_states_flat = output_new_states_tensor->flat<int32>();
    Tensor* output_scores_tensor = NULL;
    OP_REQUIRES_OK(context,context->allocate_output(1,&output_scores_tensor));
    auto output_scores_flat = output_scores_tensor->flat<float>();

    for(int i = 0; i < inputs_flat.size(); ++i) {
      output_new_states_flat(i) = -1;  // TODO
      output_scores_flat(i) = -1.;  // TODO
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("OpenFstTransition").Device(DEVICE_CPU),OpenFstTransitionOp);

一些问题:

运行 `print("fst:", session.run(handle))` 会引发异常 `InternalError: ndarray was 1 bytes but TF_Tensor was 98 bytes`。为什么会这样?这个错误是什么意思?

当会话被销毁时,程序因断言失败而崩溃:`F .../site-packages/tensorflow/include/tensorflow/core/lib/core/refcount.h:79] Check failed: ref_.load() == 0 (1 vs. 0)`。堆栈跟踪如下:

2   libsystem_c.dylib               0x00007fff6687d1ae abort + 127
3   libtensorflow_framework.so      0x0000000107382e70 tensorflow::internal::LogMessageFatal::~LogMessageFatal() + 32
4   libtensorflow_framework.so      0x0000000107382e80 tensorflow::internal::LogMessageFatal::~LogMessageFatal() + 16
5   tf-resource-op.so               0x0000000128093d82 tensorflow::core::RefCounted::~RefCounted() + 162
6   tf-resource-op.so               0x0000000128095e2e OpenFstInstance::~OpenFstInstance() + 46 (tf-resource-op.cc:40)
7   libtensorflow_framework.so      0x000000010726a1f3 tensorflow::ResourceMgr::DoDelete(std::__1::basic_string<char,std::__1::char_traits<char>,std::__1::allocator<char> > const&,unsigned long long,std::__1::basic_string<char,std::__1::allocator<char> > const&) + 307
8   libtensorflow_framework.so      0x000000010726a433 tensorflow::ResourceMgr::DoDelete(std::__1::basic_string<char,std::__1::type_index,std::__1::allocator<char> > const&) + 99
9   tf-resource-op.so               0x000000012809457b tensorflow::ResourceOpKernel<OpenFstInstance>::~ResourceOpKernel() + 91 (resource_op_kernel.h:60)
10  tf-resource-op.so               0x0000000128094694 OpenFstLoadOp::~OpenFstLoadOp() + 52 (tf-resource-op.cc:53)
11  libtensorflow_framework.so      0x0000000107264d4f tensorflow::OpSegment::Item::~Item() + 63
12  libtensorflow_framework.so      0x000000010726558f tensorflow::OpSegment::RemoveHold(std::__1::basic_string<char,std::__1::allocator<char> > const&) + 303
13  _pywrap_tensorflow_internal.so  0x0000000113b7b712 tensorflow::DirectSession::~DirectSession() + 274
14  _pywrap_tensorflow_internal.so  0x0000000113b7bade tensorflow::DirectSession::~DirectSession() + 14

我猜想OpenFstInstance对象的引用计数有些混乱。但为什么?我该如何解决?

(相关的是this question。)

zl1002007 回答:具有资源句柄的 TensorFlow 自定义 C++ op

在构建标志中添加 `-DNDEBUG` 可以解决此问题。TF issue 17316 中说明了这种解决方法。(TensorFlow 的发布版本是带 `NDEBUG` 编译的;自定义 op 若不加此标志,头文件中内联进 tf-resource-op.so 的 `RefCounted` 析构函数会保留调试断言,这正是堆栈跟踪中 `~RefCounted()` 帧所在的位置。)

本文链接:https://www.f2er.com/3134794.html

大家都在问