目前我正在创建一个Unet来进行图像分割。在我的Unet中,最后一个输出通道是183,这是类的大小。当我做训练时,它显示以下错误。我尝试了BCEWellLogitsLoss函数,它对我有用,但它没有显示正确的输出图像,它慢慢训练成黑色图像,而不是GrayScale类图像。预期的图像输出图像
def Append_files():
train_folder = "train/"
train_label_folder = "train_label/"
root_dir = "data/Train_Data/"
ext_cond = lambda filename:filename.endswith('.png')
train_data = os.path.join(root_dir,train_folder)
train_data_path = lambda filename: True
train_data_label = os.path.join(root_dir,train_label_folder)
train_data_label_path = lambda filename: True
filtered_train_data = []
filtered_train_label_data = []
for path, _, files in os.walk(train_data):
files.sort()
for file in files:
if train_data_path(file) and ext_cond(file):
full_file_path = os.path.join(path,file)
filtered_train_data.append(full_file_path)
for path,_,files in os.walk(train_data_label):
files.sort()
for file in files:
if train_data_label_path(file) and ext_cond(file):
full_file_path = os.path.join(path,file)
filtered_train_label_data.append(full_file_path)
return filtered_train_data, filtered_train_label_data
def open_image(train_data,train_data_label):
data = Image.open(train_data)
data = np.array(data)
label = Image.open(train_data_label)
label = np.array(label)
return data,label
def load_dataset(train_data,train_label_data):
train_set_loader = transforms.Compose([transforms.ToPILImage(),transforms.CenterCrop(size=(572,572)),transforms.ToTensor()])
train_label_set_loader = transforms.Compose([transforms.ToPILImage(),transforms.CenterCrop(size=(388,388)),transforms.ToTensor()])
train_set = train_set_loader(train_data)
train_label_set = train_label_set_loader(train_label_data)
return train_set, train_label_set
### Main
if __name__ == '__main__':
epochs = 100
loss = 0
train_data,train_label_data = Append_files()
data_len = len(train_data)
model = UNet(3,183).to(device)
#loss_func = nn.BCEWithLogitsLoss()
loss_func = nn.BCELoss()
get_optimizer = torch.optim.Adam(model.parameters(),0.001,betas=(0.9,0.999),eps=0.1,weight_decay = 0,amsgrad=False)
for epoch in range(epochs):
print ('starting epoch{}/{}.'.format(epoch+1,epochs))
for i in range(data_len):
data, label = open_image(train_data[i],train_label_data[i])
train, label = load_dataset(data,label)
train = Variable(train).to(device)
label = Variable(label).to(device)
train = train.unsqueeze(0)
label = label.unsqueeze(0)
get_optimizer.zero_grad() #train.shape = 1,1,388,388
train_real_image = model.forward(train) #train_real_image.shape = 1,1,388,388
#I can change it to 1,183,388 where 183 is the class size
#Loss Function
#Label.shape = 1,1,388,388
loss = loss_func(train_real_image,label)
print (loss)
loss.backward()
get_optimizer.step()
/opt/conda/conda-bld/pytorch_1556653183467/work/aten/src/THCUNN/BCECriterion.cu:42: Acctype bce_functor<Dtype, Acctype>::operator()(Tuple) [with Tuple = thrust::detail::tuple_of_iterator_references<thrust::device_reference<float>, thrust::device_reference<float>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>, Dtype = float, Acctype = float]: block: [23,0,0], thread: [28,0,0] Assertion `input >= 0. && input <= 1.` failed.
/opt/conda/conda-bld/pytorch_1556653183467/work/aten/src/THCUNN/BCECriterion.cu:42: Acctype bce_functor<Dtype, Acctype>::operator()(Tuple) [with Tuple = thrust::detail::tuple_of_iterator_references<thrust::device_reference<float>, thrust::device_reference<float>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>, Dtype = float, Acctype = float]: block: [23,0,0], thread: [29,0,0] Assertion `input >= 0. && input <= 1.` failed.
/opt/conda/conda-bld/pytorch_1556653183467/work/aten/src/THCUNN/BCECriterion.cu:42: Acctype bce_functor<Dtype, Acctype>::operator()(Tuple) [with Tuple = thrust::detail::tuple_of_iterator_references<thrust::device_reference<float>, thrust::device_reference<float>, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type, thrust::null_type>, Dtype = float, Acctype = float]: block: [23,0,0], thread: [31,0,0] Assertion `input >= 0. && input <= 1.` failed.
RuntimeError: reduce failed to synchronize: device-side assert triggered
nn. BCELoss
期望我们已经对logits应用了sigmoid激活,而nn.BCEWellLogitsLoss
期望logits作为输入,并在计算二进制交叉熵损失之前在内部对logits应用sigmoid激活。您可以查看此讨论。
因此,如果您想使用nn. BCELoss
,请确保在将日志发送到loss_func
之前对日志应用激活函数。