提问者:小点点

无法获取张量的值


当运行MNIST数据集时,我想知道在训练批处理期间我的模型实际上输出了什么。这是我的代码:(我没有添加优化器和损失函数):

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data

INPUT_NODE  = 784 # the total pixels of the input images
OUTPUT_NODE = 10  # the output varies from 0 to 9
LAYER_NODE = 500
BATCH_SIZE = 100
TRAINING_STEPS = 10

def inference(input_tensor, avg_class, weight1, biase1, weight2, biase2):
    if avg_class == None:
        layer = tf.nn.relu(tf.matmul(input_tensor, weight1) + biase1)
        return tf.matmul(layer, weight2)+biase2
    else:
        layer = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weight1)) + 
                avg_class.average(biase1))
        return tf.matmul(layer, avg_class.average(weight2)) + avg_class.average(biase2)


def train(mnist):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name = 'x-input')
    y = tf.placeholder(tf.float32, [None, OUTPUT_NODE],name = 'y-input')

    weight1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER_NODE], stddev = 0.1))
    biase1  = tf.Variable(tf.constant(0.1, shape = [LAYER_NODE]))
    weight2 = tf.Variable(tf.truncated_normal([LAYER_NODE, OUTPUT_NODE], stddev = 0.1))
    biase2  = tf.Variable(tf.constant(0.1, shape = [OUTPUT_NODE]))

    out = inference(x, None, weight1, biase1, weight2, biase2)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        validate_feed = {x:mnist.validation.images, y:mnist.validation.labels}
        test_feed = {x:mnist.test.images, y:mnist.test.labels}

        for i in range(TRAINING_STEPS):

            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(out, feed_dict= {x:xs, y:ys})
            print(out)

def main(arg = None):
    mnist = input_data.read_data_sets("/home/vincent/Tensorflow/MNIST/data/", one_hot = True)
    train(mnist)

if __name__ == '__main__':
    tf.app.run()

我试图打印出:

张量(“加法1:0”,形状=(?,10),数据类型=浮点32)

如果我想知道out的价值,我应该怎么做?我试图print(out.eval()),它引发了错误


共1个答案

匿名用户

out是一个张量对象。如果要获取其值,请替换

sess.run(out, feed_dict= {x:xs, y:ys})
print(out)

具有

res_out=sess.run(out, feed_dict= {x:xs, y:ys})
print(res_out)