首页
Preview

如何从Java程序中调用训练好的TensorFlow模型

TensorFlow机器学习模型的主要开发和训练语言是Python。然而,许多企业服务器程序都是用Java编写的。因此,你经常会遇到这样的情况,需要从Java程序中调用你在Python中训练的TensorFlow模型。

如果你在Google Cloud平台上使用Cloud ML Engine,这是没有问题的。在Cloud ML Engine中,预测是通过REST API调用进行的,因此你可以使用任何编程语言来实现。但是,如果你已经下载了TensorFlow模型,想要离线进行预测呢?

下面是如何使用在Python中训练的Tensorflow模型在Java中进行预测的方法。

注意:Tensorflow团队现在已经开始添加Java绑定。请参见https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java 获取详细信息。首先尝试这个,如果它对你不起作用,再来这里……

在Python中将模型文件写出

第一件事是在Python中以两种格式保存TensorFlow模型:(a)将权重、偏置等保存为“saver_def”文件(b)将图本身保存为protobuf文件。为了保持理智,你可能希望将图保存为文本和二进制protobuf格式。阅读文本格式会对你有所帮助,可以找到TensorFlow分配给未显式分配名称的节点的名称。从Python编写这三个文件的代码如下:

# save the graph in protobuf format
tf.train.write_graph(session.graph_def, '.', 'trained_model.pb', as_text=False)

# save the graph in human-readable text format
tf.train.write_graph(session.graph_def, '.', 'trained_model.txt', as_text=True)

# save the weights and biases
saver = tf.train.Saver()
saver.save(session, './trained_model.sd')

在我的情况下,从save_def打印出来的两个魔术字符串是_save/Const:0_和_save/restore_all_——这就是你在Java代码中看到的。如果你的不同,请在编写Java代码时进行更改。

.sd文件包含权重、偏置等(图中变量的实际值)。.proto文件是包含计算图的二进制文件,.txt是相应的文本版本。

在Java中调用Tensorflow C++

即使你可能已经使用Tensorflow在Python中向你的模型提供数据并对其进行训练,Tensorflow Python包实际上调用C++实现来执行实际的工作。因此,我们可以使用Java Native Interface(JNI)直接调用C++,并使用C++从Java创建图形并从模型恢复权重和偏置。

与其手动编写所有JNI调用,不如使用一个名为JavaCpp的开源库来完成。要使用JavaCpp,请将以下依赖项添加到Java Maven pom.xml中:

<dependency>
  <groupId>org.bytedeco</groupId>
  <artifactId>tensorflow-presets</artifactId>
  <version>1.15.0-1.4.4</version>
</dependency>

如果你使用其他构建管理系统,请将Tensorflow的Javacpp预设和它的所有依赖项添加到你的应用程序类路径中。

在Java中创建模型

在Java代码中,按照以下方式读取proto文件以创建Graph定义(省略了导入):

GraphDef graphDef = GraphDef.parseFrom(Files.readAllBytes(Paths.get("trained_model.pb")));
Graph graph = new Graph();
graph.importGraphDef(graphDef);

接下来,使用Session::Run()从保存的模型文件中恢复权重和偏置。注意saver_def中的魔术字符串如何使用。

try (Session session = new Session(graph)) {
  // Restore variables from disk.
  session.run(new String[]{"save/restore_all"}, new Tensor[]{}, new ArrayList<>());
}

在Java中进行预测

此时,你的模型已经准备好了。你现在可以使用它进行预测。这类似于在Python中的做法——你必须为所有占位符传递值并评估输出节点。不同之处在于,你必须知道占位符和输出节点的实际名称。如果你没有在Python中为这些节点分配唯一的名称,Tensorflow会为它们分配名称。你可以查看写出的trained_model.txt文件来查找它们的名称。或者,你可以回到Python代码中,为关键节点分配你记得的名称。在我的情况下,输入占位符称为Placeholder;dropout节点占位符称为Placeholder_2,输出节点称为Sigmoid。你将在下面的Session::Run()调用中看到这些引用。

在我的情况下,神经网络使用5个预测变量。假设我有一个输入数组,这些变量是我神经网络模型的预测变量,并且想要对2组这样的输入进行预测,那么我的输入是一个2x5矩阵。我的神经网络只有一个输出,因此对于2组输入,输出张量是一个2x1矩阵。dropout节点被赋予了1.0的硬编码输入(在预测中,我们保留所有节点——dropout概率仅用于训练)。所以,我有:

float[][] input = new float[][]{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}};
try (Tensor<Float> inputTensor = Tensors.create(input);
    Session session = new Session(graph)) {
  Tensor<Float> result = session.runner()
      .feed("Placeholder", inputTensor)
      .feed("Placeholder_2", Tensors.create(1.0f))
      .fetch("Sigmoid").run().get(0).expect(Float.class);
  System.out.println(Arrays.deepToString(result.copyTo(new float[2][1]))); // [[0.3640868], [0.9851471]]
}

就这样——你现在使用Java进行预测。有几个步骤,但这是可以预料的,当你混合3种编程语言(Python、C++和Java)时。但重要的是,它是可以做到的,并且相对简单。

当然,这样做并没有利用硬件加速和分布式计算。如果你想以非常高的速率实时进行预测,你应该考虑使用Cloud ML Engine。

版权声明:本文内容由TeHub注册用户自发贡献,版权归原作者所有,TeHub社区不拥有其著作权,亦不承担相应法律责任。 如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

点赞(0)
收藏(0)
阿波
The minute I see you, I want your clothes gone!

评论(0)

添加评论