Java源码示例:org.deeplearning4j.zoo.PretrainedType
示例1
/**
 * Builds the VGG16-based feature network used for style transfer.
 *
 * <p>Starting from the zoo VGG16 with its {@code IMAGENET} weights, each {@code blockN_pool}
 * vertex is removed (keeping its connections) and re-added under the same name as the layer
 * returned by {@code avgPoolLayer()}, fed by the same convolution layer. The input type is set
 * to a flattened convolutional input of {@code NeuralStyleTransfer.HEIGHT x WIDTH x CHANNELS}.
 * The vertices {@code fc2}, {@code fc1}, {@code flatten} and {@code predictions} are then
 * removed, which judging by the names is the classifier head. Finally {@code block5_pool} is
 * made the graph output.
 *
 * @return the rewired graph, after {@code initGradientsView()} has been called on it
 * @throws IOException if the pretrained model cannot be loaded
 */
public static ComputationGraph loadModel() throws IOException {
    // Each row is {pooling vertex to replace, layer that feeds it}. The row order reproduces
    // the builder call sequence block5 -> block1.
    final String[][] poolToInput = {
            {"block5_pool", "block5_conv3"},
            {"block4_pool", "block4_conv3"},
            {"block3_pool", "block3_conv3"},
            {"block2_pool", "block2_conv2"},
            {"block1_pool", "block1_conv2"},
    };

    ZooModel zoo = new org.deeplearning4j.zoo.model.VGG16();
    ComputationGraph pretrained = (ComputationGraph) zoo.initPretrained(PretrainedType.IMAGENET);

    TransferLearning.GraphBuilder builder = new TransferLearning.GraphBuilder(pretrained);
    for (String[] row : poolToInput) {
        String poolName = row[0];
        String inputName = row[1];
        builder = builder
                .removeVertexKeepConnections(poolName)
                .addLayer(poolName, avgPoolLayer(), inputName);
    }

    ComputationGraph featureNet = builder
            .setInputTypes(InputType.convolutionalFlat(
                    NeuralStyleTransfer.HEIGHT, NeuralStyleTransfer.WIDTH, NeuralStyleTransfer.CHANNELS))
            .removeVertexAndConnections("fc2")
            .removeVertexAndConnections("fc1")
            .removeVertexAndConnections("flatten")
            .removeVertexAndConnections("predictions")
            .setOutputs("block5_pool")
            .build();

    // NOTE(review): presumably the gradient view is allocated here because callers compute
    // gradients through this graph; confirm with the caller.
    featureNet.initGradientsView();
    System.out.println("vgg16Transfer.summary() = " + featureNet.summary());
    return featureNet;
}
示例2
/**
 * Loads the zoo VGG16 with its {@code IMAGENET} weights, initialises its gradient view and logs
 * the model summary.
 *
 * @return the pretrained VGG16 graph
 * @throws IOException if the pretrained model cannot be loaded
 */
private ComputationGraph loadModel() throws IOException {
    ZooModel model = VGG16.builder().build();
    ComputationGraph graph = (ComputationGraph) model.initPretrained(PretrainedType.IMAGENET);
    graph.initGradientsView();
    String summary = graph.summary();
    log.info(summary);
    return graph;
}
示例3
/**
 * Guesses which pretrained weight set fits an input of the given shape.
 *
 * <p>NOTE(review): index 2 is treated as the channel count and indices 0/1 as spatial sizes,
 * i.e. a {@code [height, width, channels]} layout. Confirm that callers pass shapes that way.
 *
 * @param shape the input shape; a {@code null} array throws {@link NullPointerException}
 * @return {@code MNIST} for rank <= 2 or a rank-3 shape whose last dimension is 1;
 *     {@code IMAGENET} when {@code shape[0]} or {@code shape[1]} exceeds 32;
 *     otherwise {@code CIFAR10}
 */
private static PretrainedType inferPretrainedTypeFromShape(final long[] shape) {
    final int rank = shape.length;
    // shape[2] is only read once rank == 3 has been confirmed.
    final boolean grayscale = rank <= 2 || (rank == 3 && shape[2] == 1);
    if (grayscale) {
        return PretrainedType.MNIST;
    }
    final long smallSideLimit = 32;
    final boolean largeRgb = shape[0] > smallSideLimit || shape[1] > smallSideLimit;
    return largeRgb ? PretrainedType.IMAGENET : PretrainedType.CIFAR10;
}
示例4
/**
 * Looks up where the pretrained weights for this model can be downloaded.
 *
 * @param pretrainedType the requested weight set
 * @return {@code KerasConstants.Locations.get(modelPrettyName())} for {@code IMAGENET};
 *     {@code null} for any other value, including {@code null}
 */
public String pretrainedUrl(PretrainedType pretrainedType) {
    // Only IMAGENET has a location lookup; everything else maps to null.
    return pretrainedType == PretrainedType.IMAGENET
            ? KerasConstants.Locations.get(modelPrettyName())
            : null;
}
示例5
/**
 * Looks up the checksum recorded for this model's pretrained weights.
 *
 * <p>NOTE(review): if {@code KerasConstants.Checksums.get} returns a boxed {@code Long}, a
 * missing entry would throw {@link NullPointerException} on unboxing. Confirm every model name
 * is present.
 *
 * @param pretrainedType the requested weight set
 * @return {@code KerasConstants.Checksums.get(modelPrettyName())} for {@code IMAGENET};
 *     {@code 0L} otherwise
 */
public long pretrainedChecksum(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET
            ? KerasConstants.Checksums.get(modelPrettyName())
            : 0L;
}
示例6
/**
 * Looks up where the pretrained weights for this model can be downloaded.
 *
 * @param pretrainedType the requested weight set
 * @return {@code KerasConstants.Locations.get(modelPrettyName())} for {@code IMAGENET};
 *     {@code null} for any other value, including {@code null}
 */
public String pretrainedUrl(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        // No location is known for other weight sets.
        return null;
    }
    return KerasConstants.Locations.get(modelPrettyName());
}
示例7
/**
 * Looks up the checksum recorded for this model's pretrained weights.
 *
 * <p>NOTE(review): if {@code KerasConstants.Checksums.get} returns a boxed {@code Long}, a
 * missing entry would throw {@link NullPointerException} on unboxing. Confirm every model name
 * is present.
 *
 * @param pretrainedType the requested weight set
 * @return {@code KerasConstants.Checksums.get(modelPrettyName())} for {@code IMAGENET};
 *     {@code 0L} otherwise
 */
public long pretrainedChecksum(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return 0L;
    }
    return KerasConstants.Checksums.get(modelPrettyName());
}
示例8
/**
 * Resolves the download URL of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return the URL for {@code models/tiny-yolo-voc_dl4j_inference.v2.zip} when
 *     {@code IMAGENET} is requested; {@code null} otherwise
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET
            ? DL4JResources.getURLString("models/tiny-yolo-voc_dl4j_inference.v2.zip")
            : null;
}
示例9
/**
 * Returns the expected checksum of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return {@code 1256226465L} for {@code IMAGENET}; {@code 0L} otherwise
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET ? 1256226465L : 0L;
}
示例10
/**
 * Resolves the download URL of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return the URL for {@code models/yolo2_dl4j_inference.v3.zip} when {@code IMAGENET} is
 *     requested; {@code null} otherwise
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return null;
    }
    return DL4JResources.getURLString("models/yolo2_dl4j_inference.v3.zip");
}
示例11
/**
 * Returns the expected checksum of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return {@code 3658373840L} for {@code IMAGENET}; {@code 0L} otherwise
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return 0L;
    }
    return 3658373840L;
}
示例12
/**
 * Resolves the download URL of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return the URL for {@code models/vgg19_dl4j_inference.zip} when {@code IMAGENET} is
 *     requested; {@code null} otherwise
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    String url = null;
    if (pretrainedType == PretrainedType.IMAGENET) {
        url = DL4JResources.getURLString("models/vgg19_dl4j_inference.zip");
    }
    return url;
}
示例13
/**
 * Returns the expected checksum of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return {@code 2782932419L} for {@code IMAGENET}; {@code 0L} otherwise
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    long checksum = 0L;
    if (pretrainedType == PretrainedType.IMAGENET) {
        checksum = 2782932419L;
    }
    return checksum;
}
示例14
/**
 * Resolves the download URL of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return the URL for the mobile archive when {@code IMAGENET} is requested, for the large
 *     archive when {@code IMAGENETLARGE} is requested, and {@code null} otherwise
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    String archive;
    if (pretrainedType == PretrainedType.IMAGENET) {
        archive = "models/nasnetmobile_dl4j_inference.v1.zip";
    } else if (pretrainedType == PretrainedType.IMAGENETLARGE) {
        archive = "models/nasnetlarge_dl4j_inference.v1.zip";
    } else {
        return null;
    }
    return DL4JResources.getURLString(archive);
}
示例15
/**
 * Returns the expected checksum of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return {@code 3082463801L} for {@code IMAGENET}, {@code 321395591L} for
 *     {@code IMAGENETLARGE}, {@code 0L} otherwise
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    long checksum = 0L;
    if (pretrainedType == PretrainedType.IMAGENET) {
        checksum = 3082463801L;
    } else if (pretrainedType == PretrainedType.IMAGENETLARGE) {
        checksum = 321395591L;
    }
    return checksum;
}
示例16
/**
 * Resolves the download URL of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return the URL for {@code models/unet_dl4j_segment_inference.v1.zip} when {@code SEGMENT}
 *     is requested; {@code null} otherwise
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.SEGMENT
            ? DL4JResources.getURLString("models/unet_dl4j_segment_inference.v1.zip")
            : null;
}
示例17
/**
 * Returns the expected checksum of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return {@code 712347958L} for {@code SEGMENT}; {@code 0L} otherwise
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.SEGMENT ? 712347958L : 0L;
}
示例18
/**
 * Resolves the download URL of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return the URL for {@code models/lenet_dl4j_mnist_inference.zip} when {@code MNIST} is
 *     requested; {@code null} otherwise
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.MNIST) {
        return null;
    }
    return DL4JResources.getURLString("models/lenet_dl4j_mnist_inference.zip");
}
示例19
/**
 * Returns the expected checksum of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return {@code 1906861161L} for {@code MNIST}; {@code 0L} otherwise
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.MNIST) {
        return 0L;
    }
    return 1906861161L;
}
示例20
/**
 * Resolves the download URL of the pretrained weights archive.
 *
 * <p>For {@code IMAGENET}, the 448 variant is chosen when {@code inputShape[1]} and
 * {@code inputShape[2]} are both 448; otherwise the default archive is used.
 * NOTE(review): those indices are presumably height/width of a channels-first shape; confirm.
 *
 * @param pretrainedType the requested weight set
 * @return the matching archive URL for {@code IMAGENET}; {@code null} otherwise
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return null;
    }
    boolean is448 = inputShape[1] == 448 && inputShape[2] == 448;
    return DL4JResources.getURLString(is448
            ? "models/darknet19_448_dl4j_inference.v2.zip"
            : "models/darknet19_dl4j_inference.v2.zip");
}
示例21
/**
 * Returns the expected checksum of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return for {@code IMAGENET}, {@code 1054319943L} when {@code inputShape[1]} and
 *     {@code inputShape[2]} are both 448, else {@code 691100891L}; {@code 0L} for any other type
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return 0L;
    }
    boolean is448 = inputShape[1] == 448 && inputShape[2] == 448;
    return is448 ? 1054319943L : 691100891L;
}
示例22
/**
 * Resolves the download URL of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return the archive URL for {@code IMAGENET}, {@code CIFAR10} or {@code VGGFACE};
 *     {@code null} for any other value
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    final String archive;
    if (pretrainedType == PretrainedType.IMAGENET) {
        archive = "models/vgg16_dl4j_inference.zip";
    } else if (pretrainedType == PretrainedType.CIFAR10) {
        archive = "models/vgg16_dl4j_cifar10_inference.v1.zip";
    } else if (pretrainedType == PretrainedType.VGGFACE) {
        archive = "models/vgg16_dl4j_vggface_inference.v1.zip";
    } else {
        archive = null;
    }
    return archive == null ? null : DL4JResources.getURLString(archive);
}
示例23
/**
 * Returns the expected checksum of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return {@code 3501732770L} for {@code IMAGENET}, {@code 2192260131L} for {@code CIFAR10},
 *     {@code 2706403553L} for {@code VGGFACE}, {@code 0L} otherwise
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET ? 3501732770L
            : pretrainedType == PretrainedType.CIFAR10 ? 2192260131L
            : pretrainedType == PretrainedType.VGGFACE ? 2706403553L
            : 0L;
}
示例24
/**
 * Resolves the download URL of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return the URL for {@code models/squeezenet_dl4j_inference.v2.zip} when {@code IMAGENET} is
 *     requested; {@code null} otherwise
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    String url = null;
    if (pretrainedType == PretrainedType.IMAGENET) {
        url = DL4JResources.getURLString("models/squeezenet_dl4j_inference.v2.zip");
    }
    return url;
}
示例25
/**
 * Returns the expected checksum of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return {@code 3711411239L} for {@code IMAGENET}; {@code 0L} otherwise
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    long checksum = 0L;
    if (pretrainedType == PretrainedType.IMAGENET) {
        checksum = 3711411239L;
    }
    return checksum;
}
示例26
/**
 * Loads the pretrained graph via the superclass, then enables dimension collapsing on the
 * {@code global_average_pooling2d_5} layer's configuration.
 *
 * @param pretrainedType the requested weight set
 * @return the pretrained graph with the pooling configuration adjusted
 * @throws IOException if the superclass fails to load the pretrained model
 */
@Override
public ComputationGraph initPretrained(PretrainedType pretrainedType) throws IOException {
    ComputationGraph graph = (ComputationGraph) super.initPretrained(pretrainedType);
    // Collapsing dimensions gives users [N,1000] output instead of [N,1000,1,1], which is more
    // useful and matches the non-pretrained configuration.
    GlobalPoolingLayer pooling =
            (GlobalPoolingLayer) graph.getLayer("global_average_pooling2d_5").conf().getLayer();
    pooling.setCollapseDimensions(true);
    return graph;
}
示例27
/**
 * Resolves the download URL of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return the URL for {@code models/resnet50_dl4j_inference.v3.zip} when {@code IMAGENET} is
 *     requested; {@code null} otherwise
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET
            ? DL4JResources.getURLString("models/resnet50_dl4j_inference.v3.zip")
            : null;
}
示例28
/**
 * Returns the expected checksum of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return {@code 3914447815L} for {@code IMAGENET}; {@code 0L} otherwise
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET ? 3914447815L : 0L;
}
示例29
/**
 * Resolves the download URL of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return the URL for {@code models/xception_dl4j_inference.v2.zip} when {@code IMAGENET} is
 *     requested; {@code null} otherwise
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return null;
    }
    return DL4JResources.getURLString("models/xception_dl4j_inference.v2.zip");
}
示例30
/**
 * Returns the expected checksum of the pretrained weights archive.
 *
 * @param pretrainedType the requested weight set
 * @return {@code 3277876097L} for {@code IMAGENET}; {@code 0L} otherwise
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return 0L;
    }
    return 3277876097L;
}