Ignite: how to save and re-load a trained model
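Below is the code I use to train my model. After training, how and where can I save the model and read it back, other than with the FileExporter class? Can it only go into a file, or can I store it in a cache and load it from there?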
IgniteCache<Integer, double[]> cache = ignite.getOrCreateCache("MLData_IRIS");
// extracting sepal length, sepal width, petal length, petal width
IgniteBiFunction<Integer, double[], Vector> featureExtractor = new RangeExtractor(1, 5);
IgniteBiFunction<Integer, double[], Double> labelExtractor = new PointExtractor(0);
System.out.println(">>> Create new training dataset splitter object.");
TrainTestSplit<Integer, double[]> split = new TrainTestDatasetSplitter<Integer, double[]>()
    .split(0.5, 0.5);
IgniteBiPredicate<Integer, double[]> testData = split.getTestFilter();
IgniteBiPredicate<Integer, double[]> trainData = split.getTrainFilter();
// Set up the trainer
KMeansTrainer trainer = new KMeansTrainer()
    .withDistance(new EuclideanDistance()) // other metrics are HammingDistance, ManhattanDistance
    .withAmountOfClusters(3) // number of clusters to create
    .withMaxIterations(100)
    .withEpsilon(1.0E-4D)
    .withSeed(1234L);
long t1 = System.currentTimeMillis();
KMeansModel mdl = trainer.fit(
    ignite,
    cache,
    trainData,
    featureExtractor,
    labelExtractor
);
long t2 = System.currentTimeMillis();
System.out.println("time taken to build the model : " + (t2 - t1) + " ms");
System.out.println(">>> --------------------------------------------");
System.out.println(">>> trained model: " + mdl.toString(true));
Currently Ignite has only this one mechanism: FileExporter.
However, for version 2.8 we have implemented model storage.
Example of saving a model:
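// Store the serialized model bytes in the cluster-wide model storage.
// serializedMdl is not defined in this snippet; it is assumed to hold the serialized model as a byte[].
ModelStorage storage = new ModelStorageFactory().getModelStorage(ignite);
storage.mkdirs("/");
storage.putFile("/my_model", serializedMdl);
// Describe the model: how to read it from storage and how to parse it.
ModelDescriptor desc = new ModelDescriptor(
    "MyModel",
    "My Cool Model",
    new ModelSignature("", "", ""),
    new ModelStorageModelReader("/my_model"),
    new IgniteModelParser<>()
);
// Register the descriptor under the name used to look the model up later.
ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
descStorage.put("my_model", desc);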
Example of loading a model:
Ignite ignite = Ignition.ignite();
ModelDescriptorStorage descStorage = new ModelDescriptorStorageFactory().getModelDescriptorStorage(ignite);
ModelDescriptor desc = descStorage.get(mdl);
Model<byte[], byte[]> infMdl = new SingleModelBuilder().build(desc.getReader(), desc.getParser());
Vector input = VectorUtils.of(x);
try {
    return deserialize(infMdl.predict(serialize(input)));
}
catch (IOException | ClassNotFoundException e) {
    throw new RuntimeException(e);
}
Here x is a vector of doubles and mdl is the model name.
Note: this API will be available in version 2.8. However, if you build Ignite from the master branch, you can try it right now.
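The snippets in this answer use `serializedMdl`, `serialize` and `deserialize` without defining them. The caught `IOException | ClassNotFoundException` suggests plain Java serialization, so a minimal sketch of such helpers might look like the following. The class name `ModelSerde` is hypothetical and not part of the Ignite API, and the sketch assumes that the objects passed in (the trained model, the input vector) implement `java.io.Serializable`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;

/** Hypothetical helper class (not an Ignite class): byte[] round trip via Java serialization. */
class ModelSerde {
    /** Writes any Serializable object into a byte array. */
    static byte[] serialize(Serializable obj) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        // Closing the ObjectOutputStream flushes all buffered data into bos.
        try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
            oos.writeObject(obj);
        }
        return bos.toByteArray();
    }

    /** Reads an object back from bytes produced by serialize(). */
    @SuppressWarnings("unchecked")
    static <T> T deserialize(byte[] bytes) throws IOException, ClassNotFoundException {
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return (T) ois.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        // Stand-in payload: a feature array instead of a real trained model.
        double[] features = {5.1, 3.5, 1.4, 0.2};
        byte[] bytes = serialize(features);
        double[] restored = deserialize(bytes);
        System.out.println("round trip ok: " + Arrays.equals(features, restored));
    }
}
```

Under these assumptions, `serializedMdl` in the save example could be produced with something like `ModelSerde.serialize(mdl)`, provided the trained model class is serializable.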