提问者:小点点

如何在Tensorflow JS中构建非线性网络?


如何在Tensorflow JS中构建一个接受一个输入和一个标签的模型?

以下示例演示了该场景:

x=1, y=1;

x=2, y=4;

x=3, y=9;

x=4, y=16;

到目前为止,我只能让线性模型工作。


共2个答案

匿名用户

您需要添加与线性不同的激活。这里可能的激活函数的完整列表

(async() => {
const model = tf.sequential();
 model.add(tf.layers.dense({units: 8, inputShape: [1]}));
 model.add(tf.layers.dense({units: 4, activation: 'relu'}));
 model.add(tf.layers.dense({units: 1, activation: 'relu'}));
 // Note that the untrained model is random at this point.
 const x = tf.tensor([[1], [2], [3]])
 const y = tf.tensor([[1], [4], [9]])
 model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
 const h = await model.fit(x, y, {epochs: 100})
 
 const t = model.predict(x);
 t.print()
})()
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"> </script>

匿名用户

有两个好答案的ensorflow标签上有相同的问题(“神经网络预测第n个平方”)。我将尝试使代码适应Tensorflow. js,并重复最重要的部分作为简短的答案。请阅读那边的答案以获得完整的图片。

简而言之:这个任务对于神经网络来说并不容易,你需要一个复杂的神经网络来完成这项工作。此外,你需要一个像ReLU这样的激活函数,这样输出的范围就不限于特定的范围(比如sigmoid在-1和1之间)。

在下面的网络中,我正在改编Miriam Farber答案中的代码,创建一个具有三个密集层(50个单元)和一个输出层(一个单元)的神经网络。所有层都使用ReLU激活函数。

(async () => {
  // 1000 random values between -10 and 10
  const xArray = (new Array(1000)).fill(0).map(() => [20 * (Math.random() - 0.5)]);
  const yArray = xArray.map(x => x[0] * x[0]);
  const x = tf.tensor(xArray);
  const y = tf.tensor(yArray);

  const model = tf.sequential();
  model.add(tf.layers.dense({ units: 50, activation: 'relu', inputShape: [1] }));
  model.add(tf.layers.dense({ units: 50, activation: 'relu' }));
  model.add(tf.layers.dense({ units: 50, activation: 'relu' }));
  model.add(tf.layers.dense({ units: 1, activation: 'relu' }));

  model.compile({
      optimizer: tf.train.adam(0.01),
      loss: tf.losses.meanSquaredError,
  });
  await model.fit(x, y, {
    epochs: 10,
    callbacks: {
        onEpochEnd: (epoch, logs) => console.log(`Epoch ${epoch}, loss: ${logs.loss}`),
    }
  });
  const prediction = model.predict(tf.tensor2d([[-5], [0], [5], [9], [20]]));
  console.log('Results for: -5, 0, 5, 9, 20');
  prediction.print();
})();
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.2.7/dist/tf.min.js"></script>