提问者:小点点

pred=model.predict_classes([准备(file_path)])属性错误:函数对象没有属性predict_classes


我试图加载在我的keras模型上修改蝴蝶物种分类器在tkinter我认为问题在于我如何训练我的模型

import cv2
import tensorflow as tf

CATEGORIES = ["Abyssinians", "American Shorthair", "Bengals", "Birman",
              "British Shorthairs", "Devon Rex", "Exotic Shorthairs", "Maine Coon",
              "Oriental Shorthairs", "Persians", "Ragdoll", "Scottish Folds", "Siamese", "Somali", "Sphynx"]  # will use this to convert prediction num to string value
def prepare(filepath):
    IMG_SIZE = 100 # 50 in txt-based
    img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)  # read in the image, convert to grayscale
    new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE))  # resize image to match model's expected sizing
    return new_array.reshape(-1, IMG_SIZE, IMG_SIZE, 1)  # return the image with shaping that TF wants.

model = tf.keras.models.load_model("CAT_BREEDS.model")

prediction = model.predict([prepare(r'D:\Desktop\CATS\validation\Abyssinians\45997693_52.jpg')])
print(prediction)

上面是我用来训练keras模型的代码,但是当我试图使用蝴蝶分类器预测一个类时,我得到了这个错误

pred=model.predict_classes([准备(file_path)])属性错误:函数对象没有属性predict_classes

导入Numpy作为np从tenorflow导入keras从tensorflow.keras.layers导入Dense, GlobalAveragePooling2D从tensorflow.keras.optimizers导入Adam从tensorflow.keras.preprocessing.image导入ImageDataGenerator从tensorflow.keras.models导入Model从sklearn.metrics导入confusion_matrix导入迭代工具导入matplotlib.pyplot作为plt

train\u path=r'D:\Desktop\CATS-Copy 2\train'valid\u path=r'D:\Desktop\CATS-Copy 2\validation'test\u path=r'D:\Desktop\CATS-Copy 2\test'

class_labels=[阿比西尼亚人、美国短毛人、孟加拉人、比尔曼人、英国短毛人、德文雷克斯人、异国短毛人、缅因库恩人、东方短毛人、波斯人、布娃娃、苏格兰折叠人、暹罗人、索马里人、狮身人面像]

train\u batches=ImageDataGenerator(预处理函数=keras.applications.xception.preprocess\u input)
.flow\u from\u目录(train\u路径,target\u size=(299299),classes=class\u标签,batch\u size=5)valid\u batches=ImageDataGenerator(预处理函数=keras.applications.xception.preprocess\u input)
.flow\u from\u目录(valid\u路径,target\u大小=(299299),类=类标签,批次大小=5)测试批次=图像数据生成器(预处理函数=keras.applications.xception.preprocess\u input)
从目录(测试路径,目标大小=(299299),类=类标签,批次大小=5,随机播放=False)

base_model=keras.applications.xception.xception(include_top=False)

x=base_model.output x=globalaveragepoolig2d()(x)x=density(1024,activation='relu')(x)x=density(15,activation='sigmoid')(x)model=model(inputs=base_model.input,outputs=x)

base_model.trainable=假

N=1

model.compile(Adam(lr=.0001),损失='categorical_crossentropy',指标=['准确性'])历史=model.fit_generator(train_batches,steps_per_epoch=200,validation_data=valid_batches,validation_steps=90,纪元=N,详细=1)

model_json=model.to_json(),打开(“model.json”,“w”)作为json_文件:json_文件。写入(model_json)模型。保存_权重('model_weights.h5'))

打印(“[INFO]评估模型…”)

test_labels=test_batches.classes预测=model.predict_generator(test_batches,步骤=28,冗长=1)

model.save(CAT_BREEDS模型)


共1个答案

匿名用户

我复制了你的代码。因为我没有你的数据集,我用了我自己的2类。一个错误是行x=密集(15,激活='sigmoid')(x)。因为您正在进行分类,您的激活应该是激活='softmax'。其余的代码似乎运行正常