Queue, Coordinator and Producer in TensorFlow

Let's go straight to code examples.

First, prepare the data:

import tensorflow as tf
import numpy as np
num_examples = 10
num_features = 2
data = np.reshape(np.arange(num_examples*num_features), (num_examples,num_features))
tdata = tf.constant(data,'float32')
#print (data)
print ('--------------')
sess = tf.Session()
coord = tf.train.Coordinator()

Run the code above before running any of the snippets below.

The snippets below are independent of one another.

How to use queue?

q = tf.FIFOQueue(num_examples*5,'float')
inc_q = q.enqueue_many(tdata) #ok! As long as queue is not full
x = q.dequeue() # ok! As long as queue is not empty
sess.run(tf.global_variables_initializer())
sess.run(inc_q)
sess.run(x) # array([0., 1.], dtype=float32)

Enqueue blocks when the queue is full, and dequeue blocks when the queue is empty.
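To see the blocking behaviour without hanging a session, here is a minimal plain-Python sketch. It is an analogy built on the standard-library queue.Queue, not TensorFlow's FIFOQueue; the capacity of 2 and the 0.1 s timeouts are assumptions chosen only so that the example terminates.

```python
import queue

# Plain-Python analogy (not TensorFlow): a bounded FIFO queue of capacity 2.
q = queue.Queue(maxsize=2)
q.put([0.0, 1.0])
q.put([2.0, 3.0])

# The queue is now full. A put() without a timeout would block here,
# so a short timeout is used to make the blocking visible.
try:
    q.put([4.0, 5.0], timeout=0.1)
    full_blocked = False
except queue.Full:
    full_blocked = True

first = q.get()   # FIFO order: the oldest element comes out first
second = q.get()

# The queue is now empty, so get() would block; a timeout exposes it again.
try:
    q.get(timeout=0.1)
    empty_blocked = False
except queue.Empty:
    empty_blocked = True

print(first, second, full_blocked, empty_blocked)
```

In the TensorFlow snippet above the same situation shows up as a sess.run call that simply does not return.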

How to use coordinator?

A Coordinator is used to coordinate multiple threads. Quoting the official TensorFlow documentation:

Any of the threads can call coord.request_stop() to ask for all the threads to stop. To cooperate with the requests, each thread must check for coord.should_stop() on a regular basis. coord.should_stop() returns True as soon as coord.request_stop() has been called.

q = tf.FIFOQueue(num_examples, 'float')
inc_q = q.enqueue_many(tdata)
x = q.dequeue()
qr = tf.train.QueueRunner(q, [inc_q] * 4)
enq_threads = qr.create_threads(sess, coord=coord, start=True)  # create and start 4 threads
for i in range(10):
    r = sess.run(x)
    print(r)  # one row of data per iteration: [0. 1.], [2. 3.], ..., [18. 19.]
coord.request_stop()
coord.join(enq_threads)  # q is now closed: instead of blocking, dequeue raises an error once q has no more elements

After calling coord.request_stop() and waiting for enq_threads to terminate, if we keep running the dequeue op x, an OutOfRangeError is raised once the queue is empty.
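The request_stop / should_stop handshake used above can be sketched in plain Python. MiniCoordinator below is a hypothetical stand-in built on threading.Event, not the actual tf.train.Coordinator implementation; it only mirrors the three calls used in this section, and the worker function and sleep durations are assumptions for illustration.

```python
import threading
import time


class MiniCoordinator:
    """Hypothetical stand-in for tf.train.Coordinator, built on threading.Event."""

    def __init__(self):
        self._stop_event = threading.Event()

    def request_stop(self):
        self._stop_event.set()

    def should_stop(self):
        return self._stop_event.is_set()

    def join(self, threads):
        for t in threads:
            t.join()


def worker(coord, counts, idx):
    # Each thread checks should_stop() on a regular basis, as the quoted docs require.
    while not coord.should_stop():
        counts[idx] += 1
        time.sleep(0.001)


coord_demo = MiniCoordinator()
counts = [0] * 4
threads_demo = [
    threading.Thread(target=worker, args=(coord_demo, counts, i)) for i in range(4)
]
for t in threads_demo:
    t.start()

time.sleep(0.05)
coord_demo.request_stop()      # any thread (here the main one) may ask all threads to stop
coord_demo.join(threads_demo)  # wait until every worker has noticed and exited

all_stopped = not any(t.is_alive() for t in threads_demo)
print(all_stopped, coord_demo.should_stop())
```

The point of the design is that stopping is cooperative: nothing kills the workers, they exit only because they poll should_stop() themselves.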

How to use producer?

TensorFlow provides a variety of producers so that we do not have to manage thread creation or the enqueue and dequeue operations directly.

str_q = tf.train.string_input_producer(['a','b','c','d','e'], num_epochs=3)
sess.run(tf.local_variables_initializer())  # MUST DO when num_epochs != None
y = str_q.size()
sess.run(y)  # 0, the queue is still empty at this point
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
sess.run(str_q.size())  # 15 = 5 * 3, once the queue runner has enqueued everything
item = str_q.dequeue()  # build the dequeue op once, outside the loop
try:
    while not coord.should_stop():
        r = sess.run(item)
        print(r)
except tf.errors.OutOfRangeError:
    print('no more inputs.')
finally:
    coord.request_stop()
coord.join(threads)

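A producer's job is to enqueue automatically. What it returns depends on the producer: string_input_producer returns the queue itself (which is why the code above calls str_q.dequeue()), whereas slice_input_producer returns already-dequeued Tensors.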
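What a producer automates can be sketched in plain Python: a background thread fills a bounded queue for a fixed number of epochs and then signals that nothing more will arrive. This is only an analogy, not how TensorFlow implements it; start_producer and the _CLOSED sentinel are hypothetical names. The sketch also keeps the input order, whereas string_input_producer shuffles by default.

```python
import queue
import threading

_CLOSED = object()  # hypothetical sentinel meaning "queue closed, no more inputs"


def start_producer(items, num_epochs, capacity=32):
    """Hypothetical helper: fill a bounded queue from a background thread."""
    q = queue.Queue(maxsize=capacity)

    def fill():
        for _ in range(num_epochs):
            for item in items:
                q.put(item)   # blocks while the queue is full
        q.put(_CLOSED)        # signal exhaustion, similar to closing the queue

    t = threading.Thread(target=fill, daemon=True)
    t.start()
    return q, t


q, t = start_producer(['a', 'b', 'c', 'd', 'e'], num_epochs=3)
received = []
while True:
    item = q.get()            # blocks while the queue is empty
    if item is _CLOSED:
        print('no more inputs.')
        break
    received.append(item)
t.join()
print(len(received))
```

The consumer loop never touches the thread or the enqueue side; it only dequeues until it is told the input is exhausted, which is the role OutOfRangeError plays in the TensorFlow version.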

How to use tf.train.batch?

data_q = tf.train.slice_input_producer([tdata], num_epochs=2, shuffle=False,capacity=512)
sess.run(tf.local_variables_initializer()) # MUST DO when num_epochs != None!
#batch_data = tf.train.batch([tdata],batch_size=2,enqueue_many=True) # a Tensor can also be passed in directly
batch_data = tf.train.batch([data_q],batch_size=2,enqueue_many=True)
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
sess.run(batch_data)
sess.run(batch_data)

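When the queue is empty and the producer has run out of epochs, sess.run(batch_data) raises OutOfRangeError. If batch is built directly from a Tensor (the commented-out line), OutOfRangeError is never raised.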
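To estimate how many batches are available before the inputs run out, the epoch-limited slicing and batching can be imitated with plain Python generators. This is a sketch under stated assumptions, not TensorFlow code: epoch_limited_slices and batches are hypothetical helpers, order is preserved (as with shuffle=False), and a trailing partial batch is assumed to be dropped.

```python
import itertools


def epoch_limited_slices(rows, num_epochs):
    """Yield rows one at a time for num_epochs passes (shuffle=False analogue)."""
    for _ in range(num_epochs):
        for row in rows:
            yield row


def batches(stream, batch_size):
    """Group a stream into full batches; a trailing partial batch is dropped."""
    while True:
        batch = list(itertools.islice(stream, batch_size))
        if len(batch) < batch_size:
            return
        yield batch


rows = [[2 * i, 2 * i + 1] for i in range(10)]  # same values as `data` above
all_batches = list(batches(epoch_limited_slices(rows, num_epochs=2), batch_size=2))
print(len(all_batches))  # 10 examples * 2 epochs / batch size 2
print(all_batches[0])
```

With 10 examples, 2 epochs and a batch size of 2, the stream yields 10 full batches; a further request corresponds to the point where the TensorFlow version raises OutOfRangeError.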