From d6b237d75dbe72f3be86747dc8e8824ee3d7cebf Mon Sep 17 00:00:00 2001
From: Jim Clarke
Date: Thu, 14 Jan 2021 12:08:02 -0500
Subject: [PATCH 1/3] Initial Checkin

---
 .../framework/constraints/Constraint.java | 95 ++++++++++
 .../framework/constraints/MaxNorm.java | 115 +++++++++++++
 .../framework/constraints/MinMaxNorm.java | 162 ++++++++++++++++++
 .../framework/constraints/NonNeg.java | 46 +++++
 .../framework/constraints/UnitNorm.java | 90 ++++++++++
 .../framework/constraints/MaxNormTest.java | 65 +++++++
 .../framework/constraints/MinMaxNormTest.java | 65 +++++++
 .../framework/constraints/NonNegTest.java | 110 ++++++++++++
 .../framework/constraints/UnitNormTest.java | 66 +++++++
 .../org/tensorflow/framework/utils/ND.java | 101 ++++++++++-
 .../framework/utils/TestSession.java | 14 +-
 11 files changed, 919 insertions(+), 10 deletions(-)
 create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java
 create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java
 create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java
 create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java
 create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java
 create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java
 create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java
 create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java
 create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java
new file mode 100644
index 00000000000..bf6f97b463a
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java
@@ -0,0 +1,95 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.constraints;
+
+import org.tensorflow.Operand;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * Base class for Constraints.
Constraint subclasses impose constraints on weight values + * + * @param the date type for the weights + */ +public abstract class Constraint { + + public static final float EPSILON = 1e-7f; + + private final Ops tf; + + /** + * Creates a Constraint + * + * @param tf the TensorFlow Ops + */ + public Constraint(Ops tf) { + this.tf = tf; + } + + /** + * Applies the constraint against the provided weights + * + * @param weights the weights + * @return the constrained weights + */ + public abstract Operand call(Operand weights); + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + public Ops getTF() { + return tf; + } + + /** + * Get the element-wise square root. + * + * @param x the input Operand. + * @return the element-wise square root. + */ + protected Operand sqrt(Operand x) { + Class type = x.type(); + Operand zero = cast(tf, tf.constant(0), type); + Operand inf = cast(tf, tf.constant(Float.POSITIVE_INFINITY), type); + x = tf.clipByValue(x, zero, inf); + return tf.math.sqrt(x); + } + + /** + * Element-wise value clipping. + * + * @param x the Operand to clip + * @param minValue the minimum value + * @param maxValue the maximum value + * @return the operand with clipped values + */ + protected Operand clip(Operand x, double minValue, double maxValue) { + if (x == null) throw new IllegalArgumentException("Operand x must not be null"); + Ops tf = getTF(); + Class type = x.type(); + if (maxValue < minValue) { + double tmp = maxValue; + maxValue = minValue; + minValue = tmp; + } + Operand minValueConstant = cast(tf, tf.constant(minValue), type); + Operand maxValueConstant = cast(tf, tf.constant(maxValue), type); + return tf.clipByValue(x, minValueConstant, maxValueConstant); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java new file mode 100644 index 00000000000..f55a9998ff0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java @@ -0,0 +1,115 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Constrains the weights incident to each hidden unit to have a norm less than or equal to a + * desired value. + * + * @param the data type for the weights + */ +public class MaxNorm extends Constraint { + public static final float MAX_VALUE_DEFAULT = 2.0f; + public static final int AXIS_DEFAULT = 0; + + /** the maximum norm for the incoming weights. */ + private final float maxValue; + /** integer, axis along which to calculate weight norms. 
*/ + private final int[] axes; + + /** + * Create a MaxNorm constraint using {@link #MAX_VALUE_DEFAULT} for the max value and {@link + * #AXIS_DEFAULT} for the axis. + * + * @param tf the TensorFlow Ops + */ + public MaxNorm(Ops tf) { + this(tf, MAX_VALUE_DEFAULT, AXIS_DEFAULT); + } + + /** + * Create a MaxNorm constraint using {@link #AXIS_DEFAULT} for the axis. + * + * @param tf the TensorFlow Ops + * @param maxValue the maximum norm for the incoming weights. + */ + public MaxNorm(Ops tf, float maxValue) { + this(tf, maxValue, AXIS_DEFAULT); + } + + /** + * Create a MaxNorm constraint + * + * @param tf the TensorFlow Ops + * @param maxValue the maximum norm for the incoming weights. + * @param axis axis along which to calculate weight norms. + */ + public MaxNorm(Ops tf, float maxValue, int axis) { + this(tf, maxValue, new int[] {axis}); + } + + /** + * Create a MaxNorm constraint + * + * @param tf the TensorFlow Ops + * @param maxValue the maximum norm for the incoming weights. + * @param axes axes along which to calculate weight norms. + */ + public MaxNorm(Ops tf, float maxValue, int[] axes) { + super(tf); + this.maxValue = maxValue; + this.axes = axes; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand weights) { + Ops tf = getTF(); + Class type = weights.type(); + Operand norms = + sqrt( + tf.reduceSum( + tf.math.square(weights), tf.constant(getAxes()), ReduceSum.keepDims(Boolean.TRUE))); + Operand desired = clip(norms, 0f, this.getMaxValue()); + + return tf.math.mul( + weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms))); + } + + /** + * Gets the max value + * + * @return the maxValue + */ + public float getMaxValue() { + return maxValue; + } + + /** + * Gets the axes + * + * @return the axes + */ + public int[] getAxes() { + return axes; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java new file mode 100644 index 00000000000..8388d651225 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java @@ -0,0 +1,162 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Constrains the weights to have the norm between a lower bound and an upper bound. 
+ * + * @param the data type for the weights + */ +public class MinMaxNorm extends Constraint { + public static final float MIN_VALUE_DEFAULT = 0.0F; + public static final float MAX_VALUE_DEFAULT = 1.0F; + public static final float RATE_DEFAULT = 1.0F; + public static final int AXIS_DEFAULT = 0; + + /** the minimum norm for the incoming weights. */ + private final float minValue; + /** the maximum norm for the incoming weights. */ + private final float maxValue; + + /** + * rate for enforcing the constraint: weights will be rescaled to yield (1 - rate) * norm + rate * + * norm.clip(min_value, max_value). Effectively, this means that rate=1.0 stands for strict + * enforcement of the constraint, while rate<1.0 means that weights will be rescaled at each step + * to slowly move towards a value inside the desired interval. + */ + private final float rate; + + /** axis along which to calculate weight norms. */ + private final int[] axes; + + /** + * Create a MaxNorm constraint using {@link #MIN_VALUE_DEFAULT} for the min value, {@link + * #MAX_VALUE_DEFAULT} for the max value, {@link #RATE_DEFAULT} for the rate and {@link + * #AXIS_DEFAULT} for the axis + * + * @param tf the TensorFlow Ops + */ + public MinMaxNorm(Ops tf) { + this(tf, MIN_VALUE_DEFAULT, MAX_VALUE_DEFAULT, RATE_DEFAULT, AXIS_DEFAULT); + } + + /** + * Create a MaxNorm constraint using {@link #RATE_DEFAULT} for the rate and {@link #AXIS_DEFAULT} + * for the axis + * + * @param tf the TensorFlow Ops + * @param minValue the minimum norm for the incoming weights. + * @param maxValue the maximum norm for the incoming weights. + */ + public MinMaxNorm(Ops tf, float minValue, float maxValue) { + this(tf, minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT); + } + + /** + * Create a MaxNorm constraint + * + * @param tf the TensorFlow Ops + * @param minValue the minimum norm for the incoming weights. + * @param maxValue the maximum norm for the incoming weights. + * @param rate the rate for enforcing the constraint. + * @param axis integer, axis along which to calculate weight norms. + */ + public MinMaxNorm(Ops tf, float minValue, float maxValue, float rate, int axis) { + this(tf, minValue, maxValue, rate, new int[] {axis}); + } + /** + * Create a MaxNorm constraint + * + * @param tf the TensorFlow Ops + * @param minValue the minimum norm for the incoming weights. + * @param maxValue the maximum norm for the incoming weights. + * @param rate the rate for enforcing the constraint. + * @param axes integer, axis along which to calculate weight norms. 
+ */ + public MinMaxNorm(Ops tf, float minValue, float maxValue, float rate, int[] axes) { + super(tf); + this.minValue = minValue; + this.maxValue = maxValue; + this.rate = rate; + this.axes = axes; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand weights) { + Class type = weights.type(); + Ops tf = getTF(); + Operand norms = + sqrt( + tf.reduceSum( + tf.math.square(weights), tf.constant(getAxes()), ReduceSum.keepDims(Boolean.TRUE))); + Operand desired = + tf.math.add( + tf.math.mul( + tf.dtypes.cast(tf.constant(this.getRate()), type), + clip(norms, this.getMinValue(), this.getMaxValue())), + tf.math.mul( + tf.math.sub( + tf.dtypes.cast(tf.constant(1), type), + tf.dtypes.cast(tf.constant(this.getRate()), type)), + norms)); + + return tf.math.mul( + weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms))); + } + + /** + * Gets the minValue + * + * @return the minValue + */ + public float getMinValue() { + return minValue; + } + + /** + * Gets the maxValue + * + * @return the maxValue + */ + public float getMaxValue() { + return maxValue; + } + + /** + * Gets the rate + * + * @return the rate + */ + public float getRate() { + return rate; + } + + /** + * Gets the axes + * + * @return the axes + */ + public int[] getAxes() { + return axes; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java new file mode 100644 index 00000000000..3edfa1c036b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Constrains the weights to be non-negative. + * + * @param the data type for the weights + */ +public class NonNeg extends Constraint { + + /** + * Create a NonNeg constraint + * + * @param tf the TensorFlow Ops + */ + public NonNeg(Ops tf) { + super(tf); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand weights) { + Ops tf = getTF(); + Class type = weights.type(); + return tf.math.mul( + weights, + tf.dtypes.cast(tf.math.greaterEqual(weights, tf.dtypes.cast(tf.constant(0), type)), type)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java new file mode 100644 index 00000000000..4eba2fd98c0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Constrains the weights to have unit norm. + * + * @param the data type for the weights + */ +public class UnitNorm extends Constraint { + public static final int AXIS_DEFAULT = 0; + + /** integer, axis along which to calculate weight norms. */ + private final int[] axes; + + /** + * Create a UnitNorm Constraint with the axis set to {@link #AXIS_DEFAULT} + * + * @param tf the TensorFlow Ops + */ + public UnitNorm(Ops tf) { + this(tf, AXIS_DEFAULT); + } + + /** + * Create a UnitNorm Constraint + * + * @param tf the TensorFlow Ops + * @param axis axis along which to calculate weight norms. + */ + public UnitNorm(Ops tf, int axis) { + this(tf, new int[] {axis}); + } + + /** + * Create a UnitNorm Constraint + * + * @param tf the TensorFlow Ops + * @param axes axes along which to calculate weight norms. + */ + public UnitNorm(Ops tf, int[] axes) { + super(tf); + this.axes = axes; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand weights) { + Class type = weights.type(); + + Ops tf = getTF(); + return tf.math.div( + weights, + tf.math.add( + cast(tf, tf.constant(EPSILON), type), + sqrt( + tf.reduceSum( + tf.math.square(weights), + tf.constant(getAxes()), + ReduceSum.keepDims(Boolean.TRUE))))); + } + + /** + * Gets the axes + * + * @return the axes + */ + public int[] getAxes() { + return axes; + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java new file mode 100644 index 00000000000..fa61e097b42 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java @@ -0,0 +1,65 @@ +package org.tensorflow.framework.constraints; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +class MaxNormTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + private float[] getSampleArray() { + Random rand = new Random(3537L); + float[] result = new float[100 * 100]; + for (int i = 0; i < result.length; i++) { + result[i] = rand.nextFloat() * 100 - 50; + } + result[0] = 0; + return result; + } + + /** Test of call method, of class MaxNorm. 
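Checks, for a range of max values, that every element of the constrained weights comes back no greater than the configured max value.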
*/ + @Test + public void testCall() { + float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final float[] array = getSampleArray(); + Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); + for (AtomicInteger i = new AtomicInteger(); + i.get() < testValues.length; + i.getAndIncrement()) { + MaxNorm instance = new MaxNorm<>(tf, testValues[i.get()]); + Operand result = instance.call(weights); + session.evaluate(result, (Number v) -> v.floatValue() <= testValues[i.get()]); + } + } + } + /** Test of call method, of class MaxNorm. */ + @Test + public void testCall1() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MaxNorm instance = new MaxNorm<>(tf, 2.0f); + Operand weights = + tf.constant( + new float[][] { + {0, 1, 3, 3}, {0, 0, 0, 3}, {0, 0, 0, 3}, + }); + Operand result = instance.call(weights); + float[] expected = { + 0, 1, 2, 1.1547005f, + 0, 0, 0, 1.1547005f, + 0, 0, 0, 1.1547005f + }; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java new file mode 100644 index 00000000000..70bae6b9c83 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java @@ -0,0 +1,65 @@ +package org.tensorflow.framework.constraints; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.ND; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +class MinMaxNormTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + private float[] getSampleArray() { + Random rand = new Random(3537L); + float[] result = new float[100 * 100]; + for (int i = 0; i < result.length; i++) { + result[i] = rand.nextFloat() * 100 - 50; + } + result[0] = 0; + return result; + } + + /** Test of call method, of class MinMaxNorm. 
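Checks, for a range of min values, that the column norms of the constrained weights stay between the configured min value and twice that value (within a small tolerance).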
*/ + @Test + public void testCall() { + float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final float[] array = getSampleArray(); + Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); + for (AtomicInteger i = new AtomicInteger(); + i.get() < testValues.length; + i.getAndIncrement()) { + MinMaxNorm instance = + new MinMaxNorm<>(tf, testValues[i.get()], testValues[i.get()] * 2); + Operand result = instance.call(weights); + if (tfMode == TestSession.Mode.EAGER) + evaluate(session, result.asTensor(), testValues[i.get()]); + else + try (TFloat32 tensor = + (TFloat32) session.getGraphSession().runner().fetch(result).run().get(0)) { + evaluate(session, tensor, testValues[i.get()]); + } + } + } + } + + private void evaluate(TestSession session, TFloat32 tensor, float m) { + FloatNdArray tensorArray = NdArrays.ofFloats(tensor.shape()); + tensor.copyTo(tensorArray); + tensorArray = ND.square(tensorArray); + FloatNdArray normArray = ND.sum(tensorArray, 0); + FloatNdArray normOfNormalized = ND.sqrt(normArray); + session.evaluate( + normOfNormalized, (f) -> f.floatValue() >= m && f.floatValue() <= m * 2f + 1e-5f); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java new file mode 100644 index 00000000000..3942629d6ed --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java @@ -0,0 +1,110 @@ +package org.tensorflow.framework.constraints; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.ND; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +class NonNegTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + private float[] getSampleArray() { + Random rand = new Random(3537L); + float[] result = new float[100 * 100]; + for (int i = 0; i < result.length; i++) { + result[i] = rand.nextFloat() * 100 - 50; + } + result[0] = 0; + return result; + } + + private double[] getSampleDArray() { + Random rand = new Random(3537L); + double[] result = new double[100 * 100]; + for (int i = 0; i < result.length; i++) { + result[i] = rand.nextDouble() * 100 - 50; + } + result[0] = 0; + return result; + } + + @Test + public void testTFloat32() { + float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final float[] array = getSampleArray(); + Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); + for (AtomicInteger i = new AtomicInteger(); + i.get() < testValues.length; + i.getAndIncrement()) { + MinMaxNorm instance = + new MinMaxNorm<>(tf, testValues[i.get()], testValues[i.get()] * 2); + Operand result = instance.call(weights); + if (tfMode == TestSession.Mode.EAGER) + evaluate(session, result.asTensor(), testValues[i.get()]); + else 
+ try (TFloat32 tensor = + (TFloat32) session.getGraphSession().runner().fetch(result).run().get(0)) { + evaluate(session, tensor, testValues[i.get()]); + } + } + } + } + + @Test + public void testTFloat64() { + float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final double[] array = getSampleDArray(); + Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); + for (AtomicInteger i = new AtomicInteger(); + i.get() < testValues.length; + i.getAndIncrement()) { + MinMaxNorm instance = + new MinMaxNorm<>(tf, testValues[i.get()], testValues[i.get()] * 2); + Operand result = instance.call(weights); + if (tfMode == TestSession.Mode.EAGER) + evaluate(session, result.asTensor(), testValues[i.get()]); + else + try (TFloat64 tensor = + (TFloat64) session.getGraphSession().runner().fetch(result).run().get(0)) { + evaluate(session, tensor, testValues[i.get()]); + } + } + } + } + + private void evaluate(TestSession session, TFloat32 tensor, float m) { + FloatNdArray tensorArray = NdArrays.ofFloats(tensor.shape()); + tensor.copyTo(tensorArray); + tensorArray = ND.square(tensorArray); + FloatNdArray normArray = ND.sum(tensorArray, 0); + FloatNdArray normOfNormalized = ND.sqrt(normArray); + session.evaluate( + normOfNormalized, (f) -> f.floatValue() >= m && f.floatValue() <= m * 2f + 1e-5f); + } + + private void evaluate(TestSession session, TFloat64 tensor, float m) { + DoubleNdArray tensorArray = NdArrays.ofDoubles(tensor.shape()); + tensor.copyTo(tensorArray); + tensorArray = ND.square(tensorArray); + DoubleNdArray normArray = ND.sum(tensorArray, 0); + DoubleNdArray normOfNormalized = ND.sqrt(normArray); + session.evaluate( + normOfNormalized, (f) -> f.doubleValue() >= m && f.doubleValue() <= m * 2 + 1e-5); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java new file mode 100644 index 00000000000..7b6359bcf6c --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java @@ -0,0 +1,66 @@ +package org.tensorflow.framework.constraints; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +class UnitNormTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + private float[] getSampleArray() { + Random rand = new Random(3537L); + float[] result = new float[100 * 100]; + for (int i = 0; i < result.length; i++) { + result[i] = rand.nextFloat() * 100 - 50; + } + result[0] = 0; + return result; + } + + /** Test of call method, of class MaxNorm. 
*/ + @Test + public void testTFloat32() { + float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final float[] array = getSampleArray(); + Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); + for (AtomicInteger i = new AtomicInteger(); + i.get() < testValues.length; + i.getAndIncrement()) { + MaxNorm instance = new MaxNorm<>(tf, testValues[i.get()]); + Operand result = instance.call(weights); + session.evaluate(result, (Number v) -> v.floatValue() <= testValues[i.get()]); + } + } + } + /** Test of call method, of class MaxNorm. */ + @Test + public void testCallTFloat64() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MaxNorm instance = new MaxNorm<>(tf, 2.0f); + Operand weights = + tf.constant( + new double[][] { + {0, 1, 3, 3}, {0, 0, 0, 3}, {0, 0, 0, 3}, + }); + Operand result = instance.call(weights); + double[] expected = { + 0, 1, 2, 1.1547005, + 0, 0, 0, 1.1547005, + 0, 0, 0, 1.1547005 + }; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index 0503a41dfc2..ef8bb71d724 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -14,10 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.*; import java.util.Arrays; import java.util.concurrent.atomic.AtomicBoolean; @@ -103,6 +100,23 @@ public static FloatNdArray sqrt(FloatNdArray a) { return result; } + /** + * Gets the square root of an array. + * + * @param a the array + * @return the square root of the array. + */ + public static DoubleNdArray sqrt(DoubleNdArray a) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + int nDims = a.shape().numDimensions(); + a.elements(nDims - 1) + .forEachIndexed( + (idx, v) -> { + result.setDouble(Math.sqrt(v.getDouble()), idx); + }); + return result; + } + /** * Gets the square of an array. * @@ -120,6 +134,23 @@ public static FloatNdArray square(FloatNdArray a) { return result; } + /** + * Gets the square of an array. + * + * @param a the array + * @return the square of the array. + */ + public static DoubleNdArray square(DoubleNdArray a) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + int nDims = a.shape().numDimensions(); + a.elements(nDims - 1) + .forEachIndexed( + (idx, v) -> { + result.setDouble(v.getDouble() * v.getDouble(), idx); + }); + return result; + } + /** * Adds two arrays * @@ -568,6 +599,18 @@ public static FloatNdArray sum(FloatNdArray a) { return NdArrays.scalarOf(sum.get()); } + /** + * Sum all elements of an array + * + * @param a the array + * @return an a array with one element containing the sum. 
+ */ + public static DoubleNdArray sum(DoubleNdArray a) { + AtomicReference sum = new AtomicReference(0D); + a.scalars().forEach(f -> sum.set(sum.get() + f.getDouble())); + return NdArrays.scalarOf(sum.get()); + } + /** * Sum all elements of an array based on the specified axis * @@ -579,6 +622,17 @@ public static FloatNdArray sum(FloatNdArray a, int axis) { return sum(a, axis, false); } + /** + * Sum all elements of an array based on the specified axis + * + * @param a the array + * @param axis the axis to sum + * @return an a array the sum over the axis less the diemsnion + */ + public static DoubleNdArray sum(DoubleNdArray a, int axis) { + return sum(a, axis, false); + } + /** * Sum all elements of an array based on the specified axis * @@ -618,6 +672,45 @@ public static FloatNdArray sum(FloatNdArray a, int axis, boolean keepDims) { } } + /** + * Sum all elements of an array based on the specified axis + * + * @param a the array + * @param axis the axis to sum + * @param keepDims indicates whether the dimensions over the sum should be kept or not. + * @return an a array the sum over the axis + */ + public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { + Shape shape = a.shape(); + int nDims = shape.numDimensions(); + int xis = nDims - 1 - axis; + long totalSize = shape.size(); + long axisSize = shape.size(xis); + final double[] sums = new double[(int) axisSize]; + + a.scalars() + .forEachIndexed( + (idx, f) -> { + sums[(int) idx[xis]] += f.getDouble(); + }); + + if (keepDims) { + long[] newDims = shape.asArray(); + newDims[axis] = 1; + final AtomicInteger counter = new AtomicInteger(); + DoubleNdArray arrayK = NdArrays.ofDoubles(Shape.of(newDims)); + arrayK + .elements(newDims.length - 1) + .forEachIndexed( + (idx, v) -> { + v.setDouble(sums[counter.getAndAdd(1)]); + }); + return arrayK; + } else { + return NdArrays.vectorOf(sums); + } + } + /** * Sum all elements of an array based on the specified axis * diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java index db39a330522..2c252d467c7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.utils; import org.tensorflow.*; +import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -493,15 +494,16 @@ public void evaluate(FloatNdArray input, Predicate predicate) { } /** - * Print the input to standard out + * Evaluates the input against the expected value * - - * @param input the operand to print - * @param the data type of the input + * @param input the operand to evaluate + * @param predicate The Predicate that evaluates the each value from input + * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void print(Operand input) { - print(new PrintWriter(new OutputStreamWriter(System.out)), input.asOutput()); + public void evaluate(DoubleNdArray input, Predicate predicate) { + input.scalars().forEach(f -> assertTrue(predicate.test(f.getDouble()))); } + /** * Print the input * From dc9fdb8dd4c263166b255a33011d763b631204c1 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 26 Jan 2021 11:21:36 -0500 Subject: [PATCH 2/3] Clean up JavaDoc Change float attributes to double --- 
.../framework/constraints/Constraint.java | 22 ++++++++--------- .../framework/constraints/MaxNorm.java | 12 +++++----- .../framework/constraints/MinMaxNorm.java | 24 +++++++++---------- .../framework/constraints/MaxNormTest.java | 6 ++--- 4 files changed, 31 insertions(+), 33 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java index bf6f97b463a..1bcd3bd04ad 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -58,7 +58,7 @@ public Ops getTF() { } /** - * Get the element-wise square root. + * Gets the element-wise square root. * * @param x the input Operand. * @return the element-wise square root. @@ -66,13 +66,12 @@ public Ops getTF() { protected Operand sqrt(Operand x) { Class type = x.type(); Operand zero = cast(tf, tf.constant(0), type); - Operand inf = cast(tf, tf.constant(Float.POSITIVE_INFINITY), type); - x = tf.clipByValue(x, zero, inf); - return tf.math.sqrt(x); + Operand inf = cast(tf, tf.constant(Double.POSITIVE_INFINITY), type); + return tf.math.sqrt(tf.clipByValue(x, zero, inf)); } /** - * Element-wise value clipping. + * Gets the element-wise value clipping. * * @param x the Operand to clip * @param minValue the minimum value @@ -83,13 +82,12 @@ protected Operand clip(Operand x, double minValue, double maxValue) { if (x == null) throw new IllegalArgumentException("Operand x must not be null"); Ops tf = getTF(); Class type = x.type(); - if (maxValue < minValue) { - double tmp = maxValue; - maxValue = minValue; - minValue = tmp; - } - Operand minValueConstant = cast(tf, tf.constant(minValue), type); - Operand maxValueConstant = cast(tf, tf.constant(maxValue), type); + + double min = Math.min(minValue, maxValue); + double max = Math.max(minValue, maxValue); + + Operand minValueConstant = cast(tf, tf.constant(min), type); + Operand maxValueConstant = cast(tf, tf.constant(max), type); return tf.clipByValue(x, minValueConstant, maxValueConstant); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java index f55a9998ff0..13a7ee9eb16 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java @@ -28,11 +28,11 @@ * @param the data type for the weights */ public class MaxNorm extends Constraint { - public static final float MAX_VALUE_DEFAULT = 2.0f; + public static final double MAX_VALUE_DEFAULT = 2.0; public static final int AXIS_DEFAULT = 0; /** the maximum norm for the incoming weights. */ - private final float maxValue; + private final double maxValue; /** integer, axis along which to calculate weight norms. */ private final int[] axes; @@ -52,7 +52,7 @@ public MaxNorm(Ops tf) { * @param tf the TensorFlow Ops * @param maxValue the maximum norm for the incoming weights. */ - public MaxNorm(Ops tf, float maxValue) { + public MaxNorm(Ops tf, double maxValue) { this(tf, maxValue, AXIS_DEFAULT); } @@ -63,7 +63,7 @@ public MaxNorm(Ops tf, float maxValue) { * @param maxValue the maximum norm for the incoming weights. * @param axis axis along which to calculate weight norms. 
*/ - public MaxNorm(Ops tf, float maxValue, int axis) { + public MaxNorm(Ops tf, double maxValue, int axis) { this(tf, maxValue, new int[] {axis}); } @@ -74,7 +74,7 @@ public MaxNorm(Ops tf, float maxValue, int axis) { * @param maxValue the maximum norm for the incoming weights. * @param axes axes along which to calculate weight norms. */ - public MaxNorm(Ops tf, float maxValue, int[] axes) { + public MaxNorm(Ops tf, double maxValue, int[] axes) { super(tf); this.maxValue = maxValue; this.axes = axes; @@ -100,7 +100,7 @@ public Operand call(Operand weights) { * * @return the maxValue */ - public float getMaxValue() { + public double getMaxValue() { return maxValue; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java index 8388d651225..9cc39bfcf99 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java @@ -27,15 +27,15 @@ * @param the data type for the weights */ public class MinMaxNorm extends Constraint { - public static final float MIN_VALUE_DEFAULT = 0.0F; - public static final float MAX_VALUE_DEFAULT = 1.0F; - public static final float RATE_DEFAULT = 1.0F; + public static final double MIN_VALUE_DEFAULT = 0.0; + public static final double MAX_VALUE_DEFAULT = 1.0; + public static final double RATE_DEFAULT = 1.0; public static final int AXIS_DEFAULT = 0; /** the minimum norm for the incoming weights. */ - private final float minValue; + private final double minValue; /** the maximum norm for the incoming weights. */ - private final float maxValue; + private final double maxValue; /** * rate for enforcing the constraint: weights will be rescaled to yield (1 - rate) * norm + rate * @@ -43,7 +43,7 @@ public class MinMaxNorm extends Constraint { * enforcement of the constraint, while rate<1.0 means that weights will be rescaled at each step * to slowly move towards a value inside the desired interval. */ - private final float rate; + private final double rate; /** axis along which to calculate weight norms. */ private final int[] axes; @@ -67,7 +67,7 @@ public MinMaxNorm(Ops tf) { * @param minValue the minimum norm for the incoming weights. * @param maxValue the maximum norm for the incoming weights. */ - public MinMaxNorm(Ops tf, float minValue, float maxValue) { + public MinMaxNorm(Ops tf, double minValue, double maxValue) { this(tf, minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT); } @@ -80,7 +80,7 @@ public MinMaxNorm(Ops tf, float minValue, float maxValue) { * @param rate the rate for enforcing the constraint. * @param axis integer, axis along which to calculate weight norms. */ - public MinMaxNorm(Ops tf, float minValue, float maxValue, float rate, int axis) { + public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int axis) { this(tf, minValue, maxValue, rate, new int[] {axis}); } /** @@ -92,7 +92,7 @@ public MinMaxNorm(Ops tf, float minValue, float maxValue, float rate, int axis) * @param rate the rate for enforcing the constraint. * @param axes integer, axis along which to calculate weight norms. 
*/ - public MinMaxNorm(Ops tf, float minValue, float maxValue, float rate, int[] axes) { + public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int[] axes) { super(tf); this.minValue = minValue; this.maxValue = maxValue; @@ -129,7 +129,7 @@ public Operand call(Operand weights) { * * @return the minValue */ - public float getMinValue() { + public double getMinValue() { return minValue; } @@ -138,7 +138,7 @@ public float getMinValue() { * * @return the maxValue */ - public float getMaxValue() { + public double getMaxValue() { return maxValue; } @@ -147,7 +147,7 @@ public float getMaxValue() { * * @return the rate */ - public float getRate() { + public double getRate() { return rate; } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java index fa61e097b42..08d693c9432 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java @@ -26,7 +26,7 @@ private float[] getSampleArray() { /** Test of call method, of class MaxNorm. */ @Test public void testCall() { - float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + double[] testValues = {0.1, 0.5, 3, 8, 1e-7}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); @@ -37,7 +37,7 @@ public void testCall() { i.getAndIncrement()) { MaxNorm instance = new MaxNorm<>(tf, testValues[i.get()]); Operand result = instance.call(weights); - session.evaluate(result, (Number v) -> v.floatValue() <= testValues[i.get()]); + session.evaluate(result, v -> v.floatValue() <= testValues[i.get()]); } } } @@ -47,7 +47,7 @@ public void testCall1() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MaxNorm instance = new MaxNorm<>(tf, 2.0f); + MaxNorm instance = new MaxNorm<>(tf, 2.0); Operand weights = tf.constant( new float[][] { From c04eeb6b5781fb11725aacc3a960e5daea7a3f03 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Tue, 23 Feb 2021 14:14:02 -0500 Subject: [PATCH 3/3] Refactor Constraint to only have Generic parameter on call method. Add norm method on Constraint that is leveraged by the xxxxNorm constraints. Fix unit test cases to properly test the actual classes (oops). 
Fix Javadoc --- .../annotations/org/tensorflow/op/Ops.java | 13 +-- .../framework/constraints/Constraint.java | 35 +++++-- .../framework/constraints/MaxNorm.java | 12 +-- .../framework/constraints/MinMaxNorm.java | 26 ++--- .../framework/constraints/NonNeg.java | 10 +- .../framework/constraints/UnitNorm.java | 20 +--- .../framework/constraints/MaxNormTest.java | 4 +- .../framework/constraints/MinMaxNormTest.java | 3 +- .../framework/constraints/NonNegTest.java | 94 +++---------------- .../framework/constraints/UnitNormTest.java | 68 ++++++-------- 10 files changed, 98 insertions(+), 187 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 529b0d99c39..8ac0e66a9e5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -1834,13 +1834,14 @@ public Constant constant(Shape shape, IntDataBuffer data) { } /** - * Creates a scalar of {@code type}, with the value of {@code number}. - * {@code number} may be truncated if it does not fit in the target type. + * Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not + * fit in the target type. * * @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating}) * @param number the value of the tensor * @return a constant of the passed type - * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or unknown. + * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or + * unknown. */ public Constant constant(Class type, Number number) { return Constant.tensorOf(scope, type, number); @@ -1892,14 +1893,14 @@ public Constant constantOf(T tensor) { } /** - * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. - * {@code number} may be truncated if it does not fit in the target type. + * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be + * truncated if it does not fit in the target type. * * @param toMatch the operand providing the target type * @param number the value of the tensor * @return a constant with the same type as {@code toMatch} - * @see Ops#constant(Class, Number) * @throws IllegalArgumentException if the type is unknown (which should be impossible). + * @see Ops#constant(Class, Number) */ public Constant constantOfSameType(Operand toMatch, Number number) { return Constant.tensorOfSameType(scope, toMatch, number); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java index 1bcd3bd04ad..d3094b5e9e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -16,16 +16,13 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.family.TNumber; import static org.tensorflow.framework.utils.CastHelper.cast; -/** - * Base class for Constraints. 
Constraint subclasses impose constraints on weight values - * - * @param the date type for the weights - */ -public abstract class Constraint { +/** Base class for Constraints. Constraint subclasses impose constraints on weight values */ +public abstract class Constraint { public static final float EPSILON = 1e-7f; @@ -46,7 +43,7 @@ public Constraint(Ops tf) { * @param weights the weights * @return the constrained weights */ - public abstract Operand call(Operand weights); + public abstract Operand call(Operand weights); /** * Gets the TensorFlow Ops @@ -62,8 +59,11 @@ public Ops getTF() { * * @param x the input Operand. * @return the element-wise square root. + * @param The data type for the operand and result. + * @throws IllegalArgumentException if x is null */ - protected Operand sqrt(Operand x) { + protected Operand sqrt(Operand x) { + if (x == null) throw new IllegalArgumentException("Operand x must not be null"); Class type = x.type(); Operand zero = cast(tf, tf.constant(0), type); Operand inf = cast(tf, tf.constant(Double.POSITIVE_INFINITY), type); @@ -77,8 +77,10 @@ protected Operand sqrt(Operand x) { * @param minValue the minimum value * @param maxValue the maximum value * @return the operand with clipped values + * @param The data type for the operand and result. + * @throws IllegalArgumentException if x is null */ - protected Operand clip(Operand x, double minValue, double maxValue) { + protected Operand clip(Operand x, double minValue, double maxValue) { if (x == null) throw new IllegalArgumentException("Operand x must not be null"); Ops tf = getTF(); Class type = x.type(); @@ -90,4 +92,19 @@ protected Operand clip(Operand x, double minValue, double maxValue) { Operand maxValueConstant = cast(tf, tf.constant(max), type); return tf.clipByValue(x, minValueConstant, maxValueConstant); } + + /** + * Calculates the norm of the weights along the axes + * + * @param weights the weights used to calculate the norms + * @param axes the axes along which to calculate weight norms. + * @param the data type for the weights and the result + * @return the norms + * @throws IllegalArgumentException if weights is null + */ + protected Operand norm(Operand weights, int[] axes) { + if (weights == null) throw new IllegalArgumentException("weights must not be null"); + return sqrt( + tf.reduceSum(tf.math.square(weights), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE))); + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java index 13a7ee9eb16..1dae117b113 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java @@ -16,7 +16,6 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.family.TNumber; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -24,10 +23,8 @@ /** * Constrains the weights incident to each hidden unit to have a norm less than or equal to a * desired value. 
- * - * @param the data type for the weights */ -public class MaxNorm extends Constraint { +public class MaxNorm extends Constraint { public static final double MAX_VALUE_DEFAULT = 2.0; public static final int AXIS_DEFAULT = 0; @@ -82,13 +79,10 @@ public MaxNorm(Ops tf, double maxValue, int[] axes) { /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { + public Operand call(Operand weights) { Ops tf = getTF(); Class type = weights.type(); - Operand norms = - sqrt( - tf.reduceSum( - tf.math.square(weights), tf.constant(getAxes()), ReduceSum.keepDims(Boolean.TRUE))); + Operand norms = norm(weights, getAxes()); Operand desired = clip(norms, 0f, this.getMaxValue()); return tf.math.mul( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java index 9cc39bfcf99..04b21572e55 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java @@ -16,17 +16,12 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.family.TNumber; import static org.tensorflow.framework.utils.CastHelper.cast; -/** - * Constrains the weights to have the norm between a lower bound and an upper bound. - * - * @param the data type for the weights - */ -public class MinMaxNorm extends Constraint { +/** Constrains the weights to have the norm between a lower bound and an upper bound. */ +public class MinMaxNorm extends Constraint { public static final double MIN_VALUE_DEFAULT = 0.0; public static final double MAX_VALUE_DEFAULT = 1.0; public static final double RATE_DEFAULT = 1.0; @@ -49,7 +44,7 @@ public class MinMaxNorm extends Constraint { private final int[] axes; /** - * Create a MaxNorm constraint using {@link #MIN_VALUE_DEFAULT} for the min value, {@link + * Create a MinMaxNorm constraint using {@link #MIN_VALUE_DEFAULT} for the min value, {@link * #MAX_VALUE_DEFAULT} for the max value, {@link #RATE_DEFAULT} for the rate and {@link * #AXIS_DEFAULT} for the axis * @@ -60,8 +55,8 @@ public MinMaxNorm(Ops tf) { } /** - * Create a MaxNorm constraint using {@link #RATE_DEFAULT} for the rate and {@link #AXIS_DEFAULT} - * for the axis + * Create a MinMaxNorm constraint using {@link #RATE_DEFAULT} for the rate and {@link + * #AXIS_DEFAULT} for the axis * * @param tf the TensorFlow Ops * @param minValue the minimum norm for the incoming weights. @@ -72,7 +67,7 @@ public MinMaxNorm(Ops tf, double minValue, double maxValue) { } /** - * Create a MaxNorm constraint + * Create a MinMaxNorm constraint * * @param tf the TensorFlow Ops * @param minValue the minimum norm for the incoming weights. @@ -84,7 +79,7 @@ public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int axi this(tf, minValue, maxValue, rate, new int[] {axis}); } /** - * Create a MaxNorm constraint + * Create a MinMaxNorm constraint * * @param tf the TensorFlow Ops * @param minValue the minimum norm for the incoming weights. 
@@ -102,13 +97,10 @@ public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int[] a /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { + public Operand call(Operand weights) { Class type = weights.type(); Ops tf = getTF(); - Operand norms = - sqrt( - tf.reduceSum( - tf.math.square(weights), tf.constant(getAxes()), ReduceSum.keepDims(Boolean.TRUE))); + Operand norms = norm(weights, getAxes()); Operand desired = tf.math.add( tf.math.mul( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java index 3edfa1c036b..0194b2fadb6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java @@ -18,12 +18,8 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * Constrains the weights to be non-negative. - * - * @param the data type for the weights - */ -public class NonNeg extends Constraint { +/** Constrains the weights to be non-negative. */ +public class NonNeg extends Constraint { /** * Create a NonNeg constraint @@ -36,7 +32,7 @@ public NonNeg(Ops tf) { /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { + public Operand call(Operand weights) { Ops tf = getTF(); Class type = weights.type(); return tf.math.mul( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java index 4eba2fd98c0..70bb1a59785 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java @@ -16,17 +16,12 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.ReduceSum; import org.tensorflow.types.family.TNumber; import static org.tensorflow.framework.utils.CastHelper.cast; -/** - * Constrains the weights to have unit norm. - * - * @param the data type for the weights - */ -public class UnitNorm extends Constraint { +/** Constrains the weights to have unit norm. */ +public class UnitNorm extends Constraint { public static final int AXIS_DEFAULT = 0; /** integer, axis along which to calculate weight norms. 
*/ @@ -64,19 +59,12 @@ public UnitNorm(Ops tf, int[] axes) { /** {@inheritDoc} */ @Override - public Operand call(Operand weights) { + public Operand call(Operand weights) { Class type = weights.type(); Ops tf = getTF(); return tf.math.div( - weights, - tf.math.add( - cast(tf, tf.constant(EPSILON), type), - sqrt( - tf.reduceSum( - tf.math.square(weights), - tf.constant(getAxes()), - ReduceSum.keepDims(Boolean.TRUE))))); + weights, tf.math.add(cast(tf, tf.constant(EPSILON), type), norm(weights, getAxes()))); } /** diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java index 08d693c9432..1f80388e88f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java @@ -35,7 +35,7 @@ public void testCall() { for (AtomicInteger i = new AtomicInteger(); i.get() < testValues.length; i.getAndIncrement()) { - MaxNorm instance = new MaxNorm<>(tf, testValues[i.get()]); + MaxNorm instance = new MaxNorm(tf, testValues[i.get()]); Operand result = instance.call(weights); session.evaluate(result, v -> v.floatValue() <= testValues[i.get()]); } @@ -47,7 +47,7 @@ public void testCall1() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MaxNorm instance = new MaxNorm<>(tf, 2.0); + MaxNorm instance = new MaxNorm(tf, 2.0); Operand weights = tf.constant( new float[][] { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java index 70bae6b9c83..8c2c3a54ff9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java @@ -39,8 +39,7 @@ public void testCall() { for (AtomicInteger i = new AtomicInteger(); i.get() < testValues.length; i.getAndIncrement()) { - MinMaxNorm instance = - new MinMaxNorm<>(tf, testValues[i.get()], testValues[i.get()] * 2); + MinMaxNorm instance = new MinMaxNorm(tf, testValues[i.get()], testValues[i.get()] * 2); Operand result = instance.call(weights); if (tfMode == TestSession.Mode.EAGER) evaluate(session, result.asTensor(), testValues[i.get()]); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java index 3942629d6ed..6a6fdc13536 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java @@ -2,109 +2,39 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; -import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.ndarray.DoubleNdArray; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import java.util.Random; -import java.util.concurrent.atomic.AtomicInteger; - class NonNegTest { private final TestSession.Mode[] tfModes = 
{TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - private float[] getSampleArray() { - Random rand = new Random(3537L); - float[] result = new float[100 * 100]; - for (int i = 0; i < result.length; i++) { - result[i] = rand.nextFloat() * 100 - 50; - } - result[0] = 0; - return result; - } - - private double[] getSampleDArray() { - Random rand = new Random(3537L); - double[] result = new double[100 * 100]; - for (int i = 0; i < result.length; i++) { - result[i] = rand.nextDouble() * 100 - 50; - } - result[0] = 0; - return result; - } - @Test public void testTFloat32() { - float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - final float[] array = getSampleArray(); - Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); - for (AtomicInteger i = new AtomicInteger(); - i.get() < testValues.length; - i.getAndIncrement()) { - MinMaxNorm instance = - new MinMaxNorm<>(tf, testValues[i.get()], testValues[i.get()] * 2); - Operand result = instance.call(weights); - if (tfMode == TestSession.Mode.EAGER) - evaluate(session, result.asTensor(), testValues[i.get()]); - else - try (TFloat32 tensor = - (TFloat32) session.getGraphSession().runner().fetch(result).run().get(0)) { - evaluate(session, tensor, testValues[i.get()]); - } - } + float[][] array = {{-1, 2, -3, 4}, {-10, 11, 12, -13}}; + Operand weights = tf.constant(array); + NonNeg instance = new NonNeg(tf); + Operand result = instance.call(weights); + float[] expected = {0, 2, 0, 4, 0, 11, 12, 0}; + session.evaluate(expected, result); } } @Test public void testTFloat64() { - float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - final double[] array = getSampleDArray(); - Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); - for (AtomicInteger i = new AtomicInteger(); - i.get() < testValues.length; - i.getAndIncrement()) { - MinMaxNorm instance = - new MinMaxNorm<>(tf, testValues[i.get()], testValues[i.get()] * 2); - Operand result = instance.call(weights); - if (tfMode == TestSession.Mode.EAGER) - evaluate(session, result.asTensor(), testValues[i.get()]); - else - try (TFloat64 tensor = - (TFloat64) session.getGraphSession().runner().fetch(result).run().get(0)) { - evaluate(session, tensor, testValues[i.get()]); - } - } + final double[][] array = {{-1, 2, -3, 4}, {-10, 11, 12, -13}}; + Operand weights = tf.constant(array); + NonNeg instance = new NonNeg(tf); + Operand result = instance.call(weights); + double[] expected = {0, 2, 0, 4, 0, 11, 12, 0}; + session.evaluate(expected, result); } } - - private void evaluate(TestSession session, TFloat32 tensor, float m) { - FloatNdArray tensorArray = NdArrays.ofFloats(tensor.shape()); - tensor.copyTo(tensorArray); - tensorArray = ND.square(tensorArray); - FloatNdArray normArray = ND.sum(tensorArray, 0); - FloatNdArray normOfNormalized = ND.sqrt(normArray); - session.evaluate( - normOfNormalized, (f) -> f.floatValue() >= m && f.floatValue() <= m * 2f + 1e-5f); - } - - private void evaluate(TestSession session, TFloat64 tensor, float m) { - DoubleNdArray tensorArray = NdArrays.ofDoubles(tensor.shape()); - tensor.copyTo(tensorArray); - tensorArray = ND.square(tensorArray); - DoubleNdArray normArray = ND.sum(tensorArray, 0); - DoubleNdArray normOfNormalized = ND.sqrt(normArray); - 
session.evaluate( - normOfNormalized, (f) -> f.doubleValue() >= m && f.doubleValue() <= m * 2 + 1e-5); - } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java index 7b6359bcf6c..6437ebcd760 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java @@ -3,63 +3,57 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import java.util.Random; -import java.util.concurrent.atomic.AtomicInteger; - class UnitNormTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - private float[] getSampleArray() { - Random rand = new Random(3537L); - float[] result = new float[100 * 100]; - for (int i = 0; i < result.length; i++) { - result[i] = rand.nextFloat() * 100 - 50; - } - result[0] = 0; - return result; - } - - /** Test of call method, of class MaxNorm. */ + /** Test of call method, of class UnitNorm. */ @Test public void testTFloat32() { - float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - final float[] array = getSampleArray(); - Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); - for (AtomicInteger i = new AtomicInteger(); - i.get() < testValues.length; - i.getAndIncrement()) { - MaxNorm instance = new MaxNorm<>(tf, testValues[i.get()]); - Operand result = instance.call(weights); - session.evaluate(result, (Number v) -> v.floatValue() <= testValues[i.get()]); - } + float[][][] array = { + {{0.14517927f, 0.2574964f, 0.2291325f}, {0.9145494f, 0.9378068f, 0.6827883f}}, + {{0.27121753f, 0.08317473f, 0.3770739f}, {0.25451255f, 0.18511271f, 0.5620538f}}, + {{0.40101776f, 0.25205433f, 0.05103926f}, {0.08764106f, 0.00593294f, 0.37244815f}} + }; + float[][][] expectedArray = { + {{0.1567809f, 0.2647736f, 0.31814702f}, {0.9876333f, 0.9643105f, 0.94804124f}}, + {{0.72920675f, 0.40984813f, 0.55712338f}, {0.68429305f, 0.91215323f, 0.83042956f}}, + {{0.97694125f, 0.99972269f, 0.13576831f}, {0.21350717f, 0.02353181f, 0.99074035f}} + }; + + Operand weights = tf.constant(array); + UnitNorm instance = new UnitNorm(tf, 1); + Operand result = instance.call(weights); + Operand expected = tf.constant(expectedArray); + session.evaluate(expected, result); } } - /** Test of call method, of class MaxNorm. */ + /** Test of call method, of class UnitNorm. 
*/ @Test public void testCallTFloat64() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - MaxNorm instance = new MaxNorm<>(tf, 2.0f); - Operand weights = - tf.constant( - new double[][] { - {0, 1, 3, 3}, {0, 0, 0, 3}, {0, 0, 0, 3}, - }); - Operand result = instance.call(weights); - double[] expected = { - 0, 1, 2, 1.1547005, - 0, 0, 0, 1.1547005, - 0, 0, 0, 1.1547005 + double[][][] array = { + {{0.14517927, 0.2574964, 0.2291325}, {0.9145494, 0.9378068, 0.6827883}}, + {{0.27121753, 0.08317473, 0.3770739}, {0.25451255, 0.18511271, 0.5620538}}, + {{0.40101776, 0.25205433, 0.05103926}, {0.08764106, 0.00593294, 0.37244815}} }; + double[][][] expectedArray = { + {{0.1567809, 0.2647736, 0.31814702}, {0.9876333, 0.9643105, 0.94804124}}, + {{0.72920675, 0.40984813, 0.55712338}, {0.68429305, 0.91215323, 0.83042956}}, + {{0.97694125, 0.99972269, 0.13576831}, {0.21350717, 0.02353181, 0.99074035}} + }; + UnitNorm instance = new UnitNorm(tf, 1); + Operand weights = tf.constant(array); + Operand result = instance.call(weights); + Operand expected = tf.constant(expectedArray); session.evaluate(expected, result); } }
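
[Editor's note - illustrative usage sketch, not part of this patch. It shows how the constraint classes touched above are meant to be used: construct with an Ops instance, then pass the weight Operand to call(). Only the UnitNorm(Ops, int) constructor and the call(Operand) signature exercised by the tests in this diff are taken from the patch; the eager-session setup, the UnitNormExample class name, and the sample values are assumptions made for the example.]

import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.framework.constraints.UnitNorm;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class UnitNormExample {
  public static void main(String[] args) {
    try (EagerSession session = EagerSession.create()) {
      Ops tf = Ops.create(session);
      // Two columns of weights; UnitNorm with axis 0 rescales each column to unit
      // L2 norm, dividing by the column norm plus the EPSILON guard from Constraint.
      Operand<TFloat32> weights = tf.constant(new float[][] {{3f, 0f}, {4f, 1f}});
      UnitNorm instance = new UnitNorm(tf, 0);
      Operand<TFloat32> constrained = instance.call(weights);
      // First column becomes approximately {0.6f, 0.8f}; second stays {0f, 1f}.
      System.out.println(constrained.asTensor().getFloat(0, 0)); // ~0.6
      System.out.println(constrained.asTensor().getFloat(1, 0)); // ~0.8
    }
  }
}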