Java源码示例:org.deeplearning4j.zoo.PretrainedType

示例1
/**
 * Loads the ImageNet-pretrained VGG16 zoo model and rewires it with
 * {@code TransferLearning.GraphBuilder}.
 *
 * <p>For every block, the {@code blockN_pool} vertex is removed while keeping its connections.
 * It is then re-added as the layer returned by {@code avgPoolLayer()}, fed by the last
 * convolution of that block. The input type is set to
 * {@code InputType.convolutionalFlat(HEIGHT, WIDTH, CHANNELS)} taken from
 * {@code NeuralStyleTransfer}. The {@code fc2}, {@code fc1}, {@code flatten} and
 * {@code predictions} vertices are removed with their connections, and {@code block5_pool}
 * becomes the only output. NOTE(review): presumably this swaps the original pooling for
 * average pooling; confirm against {@code avgPoolLayer()}.
 *
 * @return the rebuilt graph, with its gradient view initialised and its summary printed
 * @throws IOException if the pretrained weights cannot be loaded
 */
public static ComputationGraph loadModel() throws IOException {
    final ZooModel zoo = new org.deeplearning4j.zoo.model.VGG16();
    final ComputationGraph pretrained = (ComputationGraph) zoo.initPretrained(PretrainedType.IMAGENET);

    // {pooling vertex to replace, convolution that feeds it}, processed deepest block first.
    final String[][] poolRewiring = {
            {"block5_pool", "block5_conv3"},
            {"block4_pool", "block4_conv3"},
            {"block3_pool", "block3_conv3"},
            {"block2_pool", "block2_conv2"},
            {"block1_pool", "block1_conv2"},
    };

    TransferLearning.GraphBuilder builder = new TransferLearning.GraphBuilder(pretrained);
    for (final String[] rewiring : poolRewiring) {
        final String poolVertex = rewiring[0];
        final String feedingConv = rewiring[1];
        builder = builder.removeVertexKeepConnections(poolVertex)
                .addLayer(poolVertex, avgPoolLayer(), feedingConv);
    }
    builder = builder.setInputTypes(InputType.convolutionalFlat(
            NeuralStyleTransfer.HEIGHT, NeuralStyleTransfer.WIDTH, NeuralStyleTransfer.CHANNELS));

    // Drop these vertices (and their connections) so block5_pool can be the terminal output.
    for (final String droppedVertex : new String[] {"fc2", "fc1", "flatten", "predictions"}) {
        builder = builder.removeVertexAndConnections(droppedVertex);
    }

    final ComputationGraph rewired = builder.setOutputs("block5_pool").build();
    rewired.initGradientsView();
    System.out.println("vgg16Transfer.summary() = " + rewired.summary());
    return rewired;
}
 
示例2
/**
 * Loads VGG16 with its ImageNet-pretrained weights, without modifying the graph.
 *
 * <p>Before returning, the gradient view is initialised and the network summary is logged.
 *
 * @return the pretrained VGG16 computation graph
 * @throws IOException if the pretrained weights cannot be loaded
 */
private ComputationGraph loadModel() throws IOException {
    final ZooModel model = VGG16.builder().build();
    final ComputationGraph network = (ComputationGraph) model.initPretrained(PretrainedType.IMAGENET);
    network.initGradientsView();
    final String summary = network.summary();
    log.info(summary);
    return network;
}
 
示例3
/**
 * Guesses which pretrained weight set suits an input of the given shape.
 *
 * <p>The rules are checked in order:
 * <ul>
 *   <li>rank at most 2, or rank 3 with {@code shape[2] == 1} (treated as grayscale):
 *       {@code MNIST}</li>
 *   <li>otherwise, {@code shape[0] > 32} or {@code shape[1] > 32} (large RGB):
 *       {@code IMAGENET}</li>
 *   <li>otherwise (small RGB): {@code CIFAR10}</li>
 * </ul>
 * NOTE(review): these checks treat index 2 as the channel axis and indices 0/1 as the spatial
 * axes (channels-last). Confirm that callers build {@code shape} that way.
 *
 * @param shape input dimensions; dereferenced unconditionally, so it must not be {@code null}
 * @return the inferred pretrained type
 */
private static PretrainedType inferPretrainedTypeFromShape(final long[] shape) {
    final boolean singleChannel = shape.length <= 2 || (shape.length == 3 && shape[2] == 1);
    if (singleChannel) {
        return PretrainedType.MNIST;
    }
    final boolean largerThan32 = shape[0] > 32 || shape[1] > 32;
    return largerThan32 ? PretrainedType.IMAGENET : PretrainedType.CIFAR10;
}
 
示例4
/**
 * Resolves the download location of the pretrained weights.
 *
 * @param pretrainedType requested weight set
 * @return the {@code KerasConstants.Locations} entry keyed by {@code modelPrettyName()} when
 *     {@code pretrainedType} is {@code IMAGENET}; {@code null} for any other type
 */
public String pretrainedUrl(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET
            ? KerasConstants.Locations.get(modelPrettyName())
            : null;
}
 
示例5
/**
 * Returns the expected checksum of the pretrained weights.
 *
 * <p>NOTE(review): if {@code KerasConstants.Checksums} is a map with boxed values, a missing
 * key would surface as a {@code NullPointerException} on unboxing. Confirm every model name
 * is registered.
 *
 * @param pretrainedType requested weight set
 * @return the {@code KerasConstants.Checksums} entry keyed by {@code modelPrettyName()} for
 *     {@code IMAGENET}; {@code 0L} for any other type
 */
public long pretrainedChecksum(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return 0L;
    }
    return KerasConstants.Checksums.get(modelPrettyName());
}
 
示例6
/**
 * Resolves the download location of the pretrained weights.
 *
 * @param pretrainedType requested weight set
 * @return {@code null} unless {@code pretrainedType} is {@code IMAGENET}, in which case the
 *     {@code KerasConstants.Locations} entry for {@code modelPrettyName()}
 */
public String pretrainedUrl(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return null;
    }
    final String modelName = modelPrettyName();
    return KerasConstants.Locations.get(modelName);
}
 
示例7
/**
 * Returns the expected checksum of the pretrained weights.
 *
 * @param pretrainedType requested weight set
 * @return {@code 0L} unless {@code pretrainedType} is {@code IMAGENET}, in which case the
 *     {@code KerasConstants.Checksums} entry for {@code modelPrettyName()}
 */
public long pretrainedChecksum(PretrainedType pretrainedType) {
    if (pretrainedType == PretrainedType.IMAGENET) {
        final String modelName = modelPrettyName();
        return KerasConstants.Checksums.get(modelName);
    }
    return 0L;
}
 
示例8
/**
 * Returns the download URL of the Tiny YOLO (VOC) inference weights.
 *
 * @param pretrainedType requested weight set
 * @return the resolved URL for {@code IMAGENET}; {@code null} for any other type
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET
            ? DL4JResources.getURLString("models/tiny-yolo-voc_dl4j_inference.v2.zip")
            : null;
}
 
示例9
/**
 * Returns the expected checksum of the Tiny YOLO (VOC) inference weights.
 *
 * @param pretrainedType requested weight set
 * @return {@code 1256226465L} for {@code IMAGENET}; {@code 0L} for any other type
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET ? 1256226465L : 0L;
}
 
示例10
/**
 * Returns the download URL of the YOLO2 inference weights.
 *
 * @param pretrainedType requested weight set
 * @return the resolved URL for {@code IMAGENET}; {@code null} for any other type
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return null;
    }
    return DL4JResources.getURLString("models/yolo2_dl4j_inference.v3.zip");
}
 
示例11
/**
 * Returns the expected checksum of the YOLO2 inference weights.
 *
 * @param pretrainedType requested weight set
 * @return {@code 3658373840L} for {@code IMAGENET}; {@code 0L} for any other type
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET ? 3658373840L : 0L;
}
 
示例12
/**
 * Returns the download URL of the VGG19 inference weights.
 *
 * @param pretrainedType requested weight set
 * @return the resolved URL for {@code IMAGENET}; {@code null} for any other type
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET
            ? DL4JResources.getURLString("models/vgg19_dl4j_inference.zip")
            : null;
}
 
示例13
/**
 * Returns the expected checksum of the VGG19 inference weights.
 *
 * @param pretrainedType requested weight set
 * @return {@code 2782932419L} for {@code IMAGENET}; {@code 0L} for any other type
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return 0L;
    }
    return 2782932419L;
}
 
示例14
/**
 * Returns the download URL of the NASNet inference weights.
 *
 * <p>{@code IMAGENET} maps to the mobile variant and {@code IMAGENETLARGE} to the large variant.
 *
 * @param pretrainedType requested weight set
 * @return the resolved URL for either supported type; {@code null} for any other type
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    final String resourcePath;
    if (pretrainedType == PretrainedType.IMAGENET) {
        resourcePath = "models/nasnetmobile_dl4j_inference.v1.zip";
    } else if (pretrainedType == PretrainedType.IMAGENETLARGE) {
        resourcePath = "models/nasnetlarge_dl4j_inference.v1.zip";
    } else {
        return null;
    }
    return DL4JResources.getURLString(resourcePath);
}
 
示例15
/**
 * Returns the expected checksum of the NASNet inference weights.
 *
 * @param pretrainedType requested weight set
 * @return {@code 3082463801L} for {@code IMAGENET} (mobile), {@code 321395591L} for
 *     {@code IMAGENETLARGE}; {@code 0L} for any other type
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    long checksum = 0L;
    if (pretrainedType == PretrainedType.IMAGENET) {
        checksum = 3082463801L;
    } else if (pretrainedType == PretrainedType.IMAGENETLARGE) {
        checksum = 321395591L;
    }
    return checksum;
}
 
示例16
/**
 * Returns the download URL of the U-Net segmentation inference weights.
 *
 * @param pretrainedType requested weight set
 * @return the resolved URL for {@code SEGMENT}; {@code null} for any other type
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.SEGMENT
            ? DL4JResources.getURLString("models/unet_dl4j_segment_inference.v1.zip")
            : null;
}
 
示例17
/**
 * Returns the expected checksum of the U-Net segmentation inference weights.
 *
 * @param pretrainedType requested weight set
 * @return {@code 712347958L} for {@code SEGMENT}; {@code 0L} for any other type
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.SEGMENT ? 712347958L : 0L;
}
 
示例18
/**
 * Returns the download URL of the LeNet MNIST inference weights.
 *
 * @param pretrainedType requested weight set
 * @return the resolved URL for {@code MNIST}; {@code null} for any other type
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.MNIST) {
        return null;
    }
    return DL4JResources.getURLString("models/lenet_dl4j_mnist_inference.zip");
}
 
示例19
/**
 * Returns the expected checksum of the LeNet MNIST inference weights.
 *
 * @param pretrainedType requested weight set
 * @return {@code 1906861161L} for {@code MNIST}; {@code 0L} for any other type
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.MNIST) {
        return 0L;
    }
    return 1906861161L;
}
 
示例20
/**
 * Returns the download URL of the Darknet19 inference weights.
 *
 * <p>When {@code inputShape[1]} and {@code inputShape[2]} are both 448, the 448 variant is
 * chosen; otherwise the default variant is used. {@code inputShape} is a field of the enclosing
 * model and is only read for {@code IMAGENET}.
 *
 * @param pretrainedType requested weight set
 * @return the resolved URL for {@code IMAGENET}; {@code null} for any other type
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return null;
    }
    final boolean is448 = inputShape[1] == 448 && inputShape[2] == 448;
    final String resourcePath = is448
            ? "models/darknet19_448_dl4j_inference.v2.zip"
            : "models/darknet19_dl4j_inference.v2.zip";
    return DL4JResources.getURLString(resourcePath);
}
 
示例21
/**
 * Returns the expected checksum of the Darknet19 inference weights.
 *
 * <p>The 448 variant (both {@code inputShape[1]} and {@code inputShape[2]} equal to 448) has its
 * own checksum. {@code inputShape} is a field of the enclosing model and is only read for
 * {@code IMAGENET}.
 *
 * @param pretrainedType requested weight set
 * @return {@code 1054319943L} (448 variant) or {@code 691100891L} for {@code IMAGENET};
 *     {@code 0L} for any other type
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return 0L;
    }
    final boolean is448 = inputShape[1] == 448 && inputShape[2] == 448;
    return is448 ? 1054319943L : 691100891L;
}
 
示例22
/**
 * Returns the download URL of the VGG16 inference weights for the requested data set.
 *
 * @param pretrainedType requested weight set
 * @return the resolved URL for {@code IMAGENET}, {@code CIFAR10} or {@code VGGFACE};
 *     {@code null} for any other type
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    final String resourcePath;
    if (pretrainedType == PretrainedType.IMAGENET) {
        resourcePath = "models/vgg16_dl4j_inference.zip";
    } else if (pretrainedType == PretrainedType.CIFAR10) {
        resourcePath = "models/vgg16_dl4j_cifar10_inference.v1.zip";
    } else if (pretrainedType == PretrainedType.VGGFACE) {
        resourcePath = "models/vgg16_dl4j_vggface_inference.v1.zip";
    } else {
        return null;
    }
    return DL4JResources.getURLString(resourcePath);
}
 
示例23
/**
 * Returns the expected checksum of the VGG16 inference weights for the requested data set.
 *
 * @param pretrainedType requested weight set
 * @return {@code 3501732770L} for {@code IMAGENET}, {@code 2192260131L} for {@code CIFAR10},
 *     {@code 2706403553L} for {@code VGGFACE}; {@code 0L} for any other type
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    long checksum = 0L;
    if (pretrainedType == PretrainedType.IMAGENET) {
        checksum = 3501732770L;
    } else if (pretrainedType == PretrainedType.CIFAR10) {
        checksum = 2192260131L;
    } else if (pretrainedType == PretrainedType.VGGFACE) {
        checksum = 2706403553L;
    }
    return checksum;
}
 
示例24
/**
 * Returns the download URL of the SqueezeNet inference weights.
 *
 * @param pretrainedType requested weight set
 * @return the resolved URL for {@code IMAGENET}; {@code null} for any other type
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET
            ? DL4JResources.getURLString("models/squeezenet_dl4j_inference.v2.zip")
            : null;
}
 
示例25
/**
 * Returns the expected checksum of the SqueezeNet inference weights.
 *
 * @param pretrainedType requested weight set
 * @return {@code 3711411239L} for {@code IMAGENET}; {@code 0L} for any other type
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET ? 3711411239L : 0L;
}
 
示例26
/**
 * Loads the pretrained graph through the superclass, then sets {@code collapseDimensions} to
 * {@code true} on the configuration of the {@code global_average_pooling2d_5} layer.
 *
 * <p>Collapsing yields {@code [N,1000]} output rather than {@code [N,1000,1,1]}, which is more
 * useful for users and matches the non-pretrained configuration.
 *
 * @param pretrainedType requested weight set, passed straight to the superclass
 * @return the loaded graph with the pooling configuration adjusted
 * @throws IOException if the superclass fails to load the weights
 */
@Override
public ComputationGraph initPretrained(PretrainedType pretrainedType) throws IOException {
    final ComputationGraph graph = (ComputationGraph) super.initPretrained(pretrainedType);
    final GlobalPoolingLayer poolingConf =
            (GlobalPoolingLayer) graph.getLayer("global_average_pooling2d_5").conf().getLayer();
    poolingConf.setCollapseDimensions(true);
    return graph;
}
 
示例27
/**
 * Returns the download URL of the ResNet50 inference weights.
 *
 * @param pretrainedType requested weight set
 * @return the resolved URL for {@code IMAGENET}; {@code null} for any other type
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return null;
    }
    return DL4JResources.getURLString("models/resnet50_dl4j_inference.v3.zip");
}
 
示例28
/**
 * Returns the expected checksum of the ResNet50 inference weights.
 *
 * @param pretrainedType requested weight set
 * @return {@code 3914447815L} for {@code IMAGENET}; {@code 0L} for any other type
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    if (pretrainedType != PretrainedType.IMAGENET) {
        return 0L;
    }
    return 3914447815L;
}
 
示例29
/**
 * Returns the download URL of the Xception inference weights.
 *
 * @param pretrainedType requested weight set
 * @return the resolved URL for {@code IMAGENET}; {@code null} for any other type
 */
@Override
public String pretrainedUrl(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET
            ? DL4JResources.getURLString("models/xception_dl4j_inference.v2.zip")
            : null;
}
 
示例30
/**
 * Returns the expected checksum of the Xception inference weights.
 *
 * @param pretrainedType requested weight set
 * @return {@code 3277876097L} for {@code IMAGENET}; {@code 0L} for any other type
 */
@Override
public long pretrainedChecksum(PretrainedType pretrainedType) {
    return pretrainedType == PretrainedType.IMAGENET ? 3277876097L : 0L;
}