TensorFlow函数教程:tf.lite.Interpreter

2019-04-03 14:55 更新

tf.lite.Interpreter函数

Interpreter

别名:>

  • 类 tf.contrib.lite.Interpreter
  • 类 tf.lite.Interpreter

定义在:tensorflow/lite/python/interpreter.py

TF-Lite模型的解释器推理。

__init__
__init__(
    model_path=None,
    model_content=None
)

构造函数。

参数:

  • model_path:TF-Lite Flatbuffer文件的路径。
  • model_content:模型的内容。

可能引发的异常:

  • ValueError:如果解释器无法创建。

方法

allocate_tensors
allocate_tensors()
get_input_details
get_input_details()

获取模型输入详细信息。

返回:

输入详细信息列表。

get_output_details
get_output_details()

获取模型输出详细信息

返回:

输出详细信息列表。

get_tensor
get_tensor(tensor_index)

获取输入张量的值(获取副本)。

如果您想避免复制,请使用tensor()。

参数:

  • tensor_index:得到张量的张量指数。该值可以从get_output_details中的'index'字段获得。

返回:

一个numpy数组。

get_tensor_details
get_tensor_details()

获取具有有效张量详细信息的每个张量的张量详细信息。

未找到有关张量的所需信息的张量不会添加到列表中。这包括没有名字的临时张量。

返回:

包含张量信息的词典列表。

invoke
invoke()

调用解释器。

在调用之前,请务必设置输入大小,分配张量和填充值。

可能引发的异常:

  • ValueError:当底层解释器失败时引发ValueError。
reset_all_variables
reset_all_variables()
resize_tensor_input
resize_tensor_input(
    input_index,
    tensor_size
)

调整输入张量的大小。

参数:

  • input_index:要设置的输入的张量索引。该值可以从get_input_details中的'index'字段获得。
  • tensor_size:tensor_shape调整输入的大小。

可能引发的异常:

  • ValueError:如果解释器无法调整输入张量的大小。
set_tensor
set_tensor(
    tensor_index,
    value
)

设置输入张量的值。请注意,这将复制value的数据。

如果要避免复制,可以使用该tensor()函数获取指向tflite解释器中输入缓冲区的numpy缓冲区。

参数:

  • tensor_index:设置的张量的张量指数。该值可以从get_input_details中的'index'字段获得。
  • value:张量值设置。

可能引发的异常:

  • ValueError:如果解释器无法设置张量。
tensor
tensor(tensor_index)

返回给出当前张量缓冲区的numpy视图的函数。

这允许在没有副本的情况下读写这个张量。这更接近于C ++ Interpreter类接口的tensor()成员,因此得名。小心不要通过调用allocate_tensors()和invoke()来保持这些输出引用。

用法:

interpreter.allocate_tensors()
input = interpreter.tensor(interpreter.get_input_details()[0]["index"])
output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
for i in range(10):
  input().fill(3.)
  interpreter.invoke()
  print("inference %s" % output())

注意这个函数如何避免直接生成numpy数组。将实际numpy视图保持到数据的时间不能超过必要的时间是很重要的。如果你这样做了,则无法再调用解释器,因为解释器可能会调整大小并使引用的张量无效。NumPy API不允许底层缓冲区的任何可变性。

错误:

input = interpreter.tensor(interpreter.get_input_details()[0]["index"])()
output = interpreter.tensor(interpreter.get_output_details()[0]["index"])()
interpreter.allocate_tensors()  # This will throw RuntimeError
for i in range(10):
  input.fill(3.)
  interpreter.invoke()  # this will throw RuntimeError since input,output

参数:

  • tensor_index:得到的张量的张量指数。该值可以从get_output_details中的'index'字段获得。

返回:

一个函数,可以在任何点返回指向内部TFLite张量状态的新numpy数组。永久保持该函数是安全的,但永久保持numpy阵列是不安全的。


以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号