diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 529b0d99c39..8ac0e66a9e5 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -1834,13 +1834,14 @@ public Constant constant(Shape shape, IntDataBuffer data) { } /** - * Creates a scalar of {@code type}, with the value of {@code number}. - * {@code number} may be truncated if it does not fit in the target type. + * Creates a scalar of {@code type}, with the value of {@code number}. {@code number} may be truncated if it does not + * fit in the target type. * * @param type the type of tensor to create. Must be concrete (i.e. not {@link org.tensorflow.types.family.TFloating}) * @param number the value of the tensor * @return a constant of the passed type - * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or unknown. + * @throws IllegalArgumentException if the type is abstract (i.e. {@link org.tensorflow.types.family.TFloating}) or + * unknown. */ public Constant constant(Class type, Number number) { return Constant.tensorOf(scope, type, number); @@ -1892,14 +1893,14 @@ public Constant constantOf(T tensor) { } /** - * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. - * {@code number} may be truncated if it does not fit in the target type. + * Creates a scalar of the same type as {@code toMatch}, with the value of {@code number}. {@code number} may be + * truncated if it does not fit in the target type. * * @param toMatch the operand providing the target type * @param number the value of the tensor * @return a constant with the same type as {@code toMatch} - * @see Ops#constant(Class, Number) * @throws IllegalArgumentException if the type is unknown (which should be impossible). + * @see Ops#constant(Class, Number) */ public Constant constantOfSameType(Operand toMatch, Number number) { return Constant.tensorOfSameType(scope, toMatch, number); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java new file mode 100644 index 00000000000..d3094b5e9e9 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -0,0 +1,110 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** Base class for Constraints. Constraint subclasses impose constraints on weight values */ +public abstract class Constraint { + + public static final float EPSILON = 1e-7f; + + private final Ops tf; + + /** + * Creates a Constraint + * + * @param tf the TensorFlow Ops + */ + public Constraint(Ops tf) { + this.tf = tf; + } + + /** + * Applies the constraint against the provided weights + * + * @param weights the weights + * @return the constrained weights + */ + public abstract Operand call(Operand weights); + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + public Ops getTF() { + return tf; + } + + /** + * Gets the element-wise square root. + * + * @param x the input Operand. + * @return the element-wise square root. + * @param The data type for the operand and result. + * @throws IllegalArgumentException if x is null + */ + protected Operand sqrt(Operand x) { + if (x == null) throw new IllegalArgumentException("Operand x must not be null"); + Class type = x.type(); + Operand zero = cast(tf, tf.constant(0), type); + Operand inf = cast(tf, tf.constant(Double.POSITIVE_INFINITY), type); + return tf.math.sqrt(tf.clipByValue(x, zero, inf)); + } + + /** + * Gets the element-wise value clipping. + * + * @param x the Operand to clip + * @param minValue the minimum value + * @param maxValue the maximum value + * @return the operand with clipped values + * @param The data type for the operand and result. + * @throws IllegalArgumentException if x is null + */ + protected Operand clip(Operand x, double minValue, double maxValue) { + if (x == null) throw new IllegalArgumentException("Operand x must not be null"); + Ops tf = getTF(); + Class type = x.type(); + + double min = Math.min(minValue, maxValue); + double max = Math.max(minValue, maxValue); + + Operand minValueConstant = cast(tf, tf.constant(min), type); + Operand maxValueConstant = cast(tf, tf.constant(max), type); + return tf.clipByValue(x, minValueConstant, maxValueConstant); + } + + /** + * Calculates the norm of the weights along the axes + * + * @param weights the weights used to calculate the norms + * @param axes the axes along which to calculate weight norms. + * @param the data type for the weights and the result + * @return the norms + * @throws IllegalArgumentException if weights is null + */ + protected Operand norm(Operand weights, int[] axes) { + if (weights == null) throw new IllegalArgumentException("weights must not be null"); + return sqrt( + tf.reduceSum(tf.math.square(weights), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE))); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java new file mode 100644 index 00000000000..1dae117b113 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MaxNorm.java @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Constrains the weights incident to each hidden unit to have a norm less than or equal to a + * desired value. + */ +public class MaxNorm extends Constraint { + public static final double MAX_VALUE_DEFAULT = 2.0; + public static final int AXIS_DEFAULT = 0; + + /** the maximum norm for the incoming weights. */ + private final double maxValue; + /** integer, axis along which to calculate weight norms. */ + private final int[] axes; + + /** + * Create a MaxNorm constraint using {@link #MAX_VALUE_DEFAULT} for the max value and {@link + * #AXIS_DEFAULT} for the axis. + * + * @param tf the TensorFlow Ops + */ + public MaxNorm(Ops tf) { + this(tf, MAX_VALUE_DEFAULT, AXIS_DEFAULT); + } + + /** + * Create a MaxNorm constraint using {@link #AXIS_DEFAULT} for the axis. + * + * @param tf the TensorFlow Ops + * @param maxValue the maximum norm for the incoming weights. + */ + public MaxNorm(Ops tf, double maxValue) { + this(tf, maxValue, AXIS_DEFAULT); + } + + /** + * Create a MaxNorm constraint + * + * @param tf the TensorFlow Ops + * @param maxValue the maximum norm for the incoming weights. + * @param axis axis along which to calculate weight norms. + */ + public MaxNorm(Ops tf, double maxValue, int axis) { + this(tf, maxValue, new int[] {axis}); + } + + /** + * Create a MaxNorm constraint + * + * @param tf the TensorFlow Ops + * @param maxValue the maximum norm for the incoming weights. + * @param axes axes along which to calculate weight norms. + */ + public MaxNorm(Ops tf, double maxValue, int[] axes) { + super(tf); + this.maxValue = maxValue; + this.axes = axes; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand weights) { + Ops tf = getTF(); + Class type = weights.type(); + Operand norms = norm(weights, getAxes()); + Operand desired = clip(norms, 0f, this.getMaxValue()); + + return tf.math.mul( + weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms))); + } + + /** + * Gets the max value + * + * @return the maxValue + */ + public double getMaxValue() { + return maxValue; + } + + /** + * Gets the axes + * + * @return the axes + */ + public int[] getAxes() { + return axes; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java new file mode 100644 index 00000000000..04b21572e55 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/MinMaxNorm.java @@ -0,0 +1,154 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** Constrains the weights to have the norm between a lower bound and an upper bound. */ +public class MinMaxNorm extends Constraint { + public static final double MIN_VALUE_DEFAULT = 0.0; + public static final double MAX_VALUE_DEFAULT = 1.0; + public static final double RATE_DEFAULT = 1.0; + public static final int AXIS_DEFAULT = 0; + + /** the minimum norm for the incoming weights. */ + private final double minValue; + /** the maximum norm for the incoming weights. */ + private final double maxValue; + + /** + * rate for enforcing the constraint: weights will be rescaled to yield (1 - rate) * norm + rate * + * norm.clip(min_value, max_value). Effectively, this means that rate=1.0 stands for strict + * enforcement of the constraint, while rate<1.0 means that weights will be rescaled at each step + * to slowly move towards a value inside the desired interval. + */ + private final double rate; + + /** axis along which to calculate weight norms. */ + private final int[] axes; + + /** + * Create a MinMaxNorm constraint using {@link #MIN_VALUE_DEFAULT} for the min value, {@link + * #MAX_VALUE_DEFAULT} for the max value, {@link #RATE_DEFAULT} for the rate and {@link + * #AXIS_DEFAULT} for the axis + * + * @param tf the TensorFlow Ops + */ + public MinMaxNorm(Ops tf) { + this(tf, MIN_VALUE_DEFAULT, MAX_VALUE_DEFAULT, RATE_DEFAULT, AXIS_DEFAULT); + } + + /** + * Create a MinMaxNorm constraint using {@link #RATE_DEFAULT} for the rate and {@link + * #AXIS_DEFAULT} for the axis + * + * @param tf the TensorFlow Ops + * @param minValue the minimum norm for the incoming weights. + * @param maxValue the maximum norm for the incoming weights. + */ + public MinMaxNorm(Ops tf, double minValue, double maxValue) { + this(tf, minValue, maxValue, RATE_DEFAULT, AXIS_DEFAULT); + } + + /** + * Create a MinMaxNorm constraint + * + * @param tf the TensorFlow Ops + * @param minValue the minimum norm for the incoming weights. + * @param maxValue the maximum norm for the incoming weights. + * @param rate the rate for enforcing the constraint. + * @param axis integer, axis along which to calculate weight norms. + */ + public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int axis) { + this(tf, minValue, maxValue, rate, new int[] {axis}); + } + /** + * Create a MinMaxNorm constraint + * + * @param tf the TensorFlow Ops + * @param minValue the minimum norm for the incoming weights. + * @param maxValue the maximum norm for the incoming weights. + * @param rate the rate for enforcing the constraint. + * @param axes integer, axis along which to calculate weight norms. + */ + public MinMaxNorm(Ops tf, double minValue, double maxValue, double rate, int[] axes) { + super(tf); + this.minValue = minValue; + this.maxValue = maxValue; + this.rate = rate; + this.axes = axes; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand weights) { + Class type = weights.type(); + Ops tf = getTF(); + Operand norms = norm(weights, getAxes()); + Operand desired = + tf.math.add( + tf.math.mul( + tf.dtypes.cast(tf.constant(this.getRate()), type), + clip(norms, this.getMinValue(), this.getMaxValue())), + tf.math.mul( + tf.math.sub( + tf.dtypes.cast(tf.constant(1), type), + tf.dtypes.cast(tf.constant(this.getRate()), type)), + norms)); + + return tf.math.mul( + weights, tf.math.div(desired, tf.math.add(cast(tf, tf.constant(EPSILON), type), norms))); + } + + /** + * Gets the minValue + * + * @return the minValue + */ + public double getMinValue() { + return minValue; + } + + /** + * Gets the maxValue + * + * @return the maxValue + */ + public double getMaxValue() { + return maxValue; + } + + /** + * Gets the rate + * + * @return the rate + */ + public double getRate() { + return rate; + } + + /** + * Gets the axes + * + * @return the axes + */ + public int[] getAxes() { + return axes; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java new file mode 100644 index 00000000000..0194b2fadb6 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/NonNeg.java @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** Constrains the weights to be non-negative. */ +public class NonNeg extends Constraint { + + /** + * Create a NonNeg constraint + * + * @param tf the TensorFlow Ops + */ + public NonNeg(Ops tf) { + super(tf); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand weights) { + Ops tf = getTF(); + Class type = weights.type(); + return tf.math.mul( + weights, + tf.dtypes.cast(tf.math.greaterEqual(weights, tf.dtypes.cast(tf.constant(0), type)), type)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java new file mode 100644 index 00000000000..70bb1a59785 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/UnitNorm.java @@ -0,0 +1,78 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.constraints; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** Constrains the weights to have unit norm. */ +public class UnitNorm extends Constraint { + public static final int AXIS_DEFAULT = 0; + + /** integer, axis along which to calculate weight norms. */ + private final int[] axes; + + /** + * Create a UnitNorm Constraint with the axis set to {@link #AXIS_DEFAULT} + * + * @param tf the TensorFlow Ops + */ + public UnitNorm(Ops tf) { + this(tf, AXIS_DEFAULT); + } + + /** + * Create a UnitNorm Constraint + * + * @param tf the TensorFlow Ops + * @param axis axis along which to calculate weight norms. + */ + public UnitNorm(Ops tf, int axis) { + this(tf, new int[] {axis}); + } + + /** + * Create a UnitNorm Constraint + * + * @param tf the TensorFlow Ops + * @param axes axes along which to calculate weight norms. + */ + public UnitNorm(Ops tf, int[] axes) { + super(tf); + this.axes = axes; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand weights) { + Class type = weights.type(); + + Ops tf = getTF(); + return tf.math.div( + weights, tf.math.add(cast(tf, tf.constant(EPSILON), type), norm(weights, getAxes()))); + } + + /** + * Gets the axes + * + * @return the axes + */ + public int[] getAxes() { + return axes; + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java new file mode 100644 index 00000000000..1f80388e88f --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MaxNormTest.java @@ -0,0 +1,65 @@ +package org.tensorflow.framework.constraints; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +class MaxNormTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + private float[] getSampleArray() { + Random rand = new Random(3537L); + float[] result = new float[100 * 100]; + for (int i = 0; i < result.length; i++) { + result[i] = rand.nextFloat() * 100 - 50; + } + result[0] = 0; + return result; + } + + /** Test of call method, of class MaxNorm. */ + @Test + public void testCall() { + double[] testValues = {0.1, 0.5, 3, 8, 1e-7}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final float[] array = getSampleArray(); + Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); + for (AtomicInteger i = new AtomicInteger(); + i.get() < testValues.length; + i.getAndIncrement()) { + MaxNorm instance = new MaxNorm(tf, testValues[i.get()]); + Operand result = instance.call(weights); + session.evaluate(result, v -> v.floatValue() <= testValues[i.get()]); + } + } + } + /** Test of call method, of class MaxNorm. */ + @Test + public void testCall1() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MaxNorm instance = new MaxNorm(tf, 2.0); + Operand weights = + tf.constant( + new float[][] { + {0, 1, 3, 3}, {0, 0, 0, 3}, {0, 0, 0, 3}, + }); + Operand result = instance.call(weights); + float[] expected = { + 0, 1, 2, 1.1547005f, + 0, 0, 0, 1.1547005f, + 0, 0, 0, 1.1547005f + }; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java new file mode 100644 index 00000000000..8c2c3a54ff9 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java @@ -0,0 +1,64 @@ +package org.tensorflow.framework.constraints; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.ND; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +class MinMaxNormTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + private float[] getSampleArray() { + Random rand = new Random(3537L); + float[] result = new float[100 * 100]; + for (int i = 0; i < result.length; i++) { + result[i] = rand.nextFloat() * 100 - 50; + } + result[0] = 0; + return result; + } + + /** Test of call method, of class MinMaxNorm. */ + @Test + public void testCall() { + float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final float[] array = getSampleArray(); + Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); + for (AtomicInteger i = new AtomicInteger(); + i.get() < testValues.length; + i.getAndIncrement()) { + MinMaxNorm instance = new MinMaxNorm(tf, testValues[i.get()], testValues[i.get()] * 2); + Operand result = instance.call(weights); + if (tfMode == TestSession.Mode.EAGER) + evaluate(session, result.asTensor(), testValues[i.get()]); + else + try (TFloat32 tensor = + (TFloat32) session.getGraphSession().runner().fetch(result).run().get(0)) { + evaluate(session, tensor, testValues[i.get()]); + } + } + } + } + + private void evaluate(TestSession session, TFloat32 tensor, float m) { + FloatNdArray tensorArray = NdArrays.ofFloats(tensor.shape()); + tensor.copyTo(tensorArray); + tensorArray = ND.square(tensorArray); + FloatNdArray normArray = ND.sum(tensorArray, 0); + FloatNdArray normOfNormalized = ND.sqrt(normArray); + session.evaluate( + normOfNormalized, (f) -> f.floatValue() >= m && f.floatValue() <= m * 2f + 1e-5f); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java new file mode 100644 index 00000000000..6a6fdc13536 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/NonNegTest.java @@ -0,0 +1,40 @@ +package org.tensorflow.framework.constraints; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class NonNegTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testTFloat32() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] array = {{-1, 2, -3, 4}, {-10, 11, 12, -13}}; + Operand weights = tf.constant(array); + NonNeg instance = new NonNeg(tf); + Operand result = instance.call(weights); + float[] expected = {0, 2, 0, 4, 0, 11, 12, 0}; + session.evaluate(expected, result); + } + } + + @Test + public void testTFloat64() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + final double[][] array = {{-1, 2, -3, 4}, {-10, 11, 12, -13}}; + Operand weights = tf.constant(array); + NonNeg instance = new NonNeg(tf); + Operand result = instance.call(weights); + double[] expected = {0, 2, 0, 4, 0, 11, 12, 0}; + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java new file mode 100644 index 00000000000..6437ebcd760 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/UnitNormTest.java @@ -0,0 +1,60 @@ +package org.tensorflow.framework.constraints; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class UnitNormTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** Test of call method, of class UnitNorm. */ + @Test + public void testTFloat32() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][][] array = { + {{0.14517927f, 0.2574964f, 0.2291325f}, {0.9145494f, 0.9378068f, 0.6827883f}}, + {{0.27121753f, 0.08317473f, 0.3770739f}, {0.25451255f, 0.18511271f, 0.5620538f}}, + {{0.40101776f, 0.25205433f, 0.05103926f}, {0.08764106f, 0.00593294f, 0.37244815f}} + }; + float[][][] expectedArray = { + {{0.1567809f, 0.2647736f, 0.31814702f}, {0.9876333f, 0.9643105f, 0.94804124f}}, + {{0.72920675f, 0.40984813f, 0.55712338f}, {0.68429305f, 0.91215323f, 0.83042956f}}, + {{0.97694125f, 0.99972269f, 0.13576831f}, {0.21350717f, 0.02353181f, 0.99074035f}} + }; + + Operand weights = tf.constant(array); + UnitNorm instance = new UnitNorm(tf, 1); + Operand result = instance.call(weights); + Operand expected = tf.constant(expectedArray); + session.evaluate(expected, result); + } + } + /** Test of call method, of class UnitNorm. */ + @Test + public void testCallTFloat64() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + double[][][] array = { + {{0.14517927, 0.2574964, 0.2291325}, {0.9145494, 0.9378068, 0.6827883}}, + {{0.27121753, 0.08317473, 0.3770739}, {0.25451255, 0.18511271, 0.5620538}}, + {{0.40101776, 0.25205433, 0.05103926}, {0.08764106, 0.00593294, 0.37244815}} + }; + double[][][] expectedArray = { + {{0.1567809, 0.2647736, 0.31814702}, {0.9876333, 0.9643105, 0.94804124}}, + {{0.72920675, 0.40984813, 0.55712338}, {0.68429305, 0.91215323, 0.83042956}}, + {{0.97694125, 0.99972269, 0.13576831}, {0.21350717, 0.02353181, 0.99074035}} + }; + UnitNorm instance = new UnitNorm(tf, 1); + Operand weights = tf.constant(array); + Operand result = instance.call(weights); + Operand expected = tf.constant(expectedArray); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index 0503a41dfc2..ef8bb71d724 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -14,10 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.*; import java.util.Arrays; import java.util.concurrent.atomic.AtomicBoolean; @@ -103,6 +100,23 @@ public static FloatNdArray sqrt(FloatNdArray a) { return result; } + /** + * Gets the square root of an array. + * + * @param a the array + * @return the square root of the array. + */ + public static DoubleNdArray sqrt(DoubleNdArray a) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + int nDims = a.shape().numDimensions(); + a.elements(nDims - 1) + .forEachIndexed( + (idx, v) -> { + result.setDouble(Math.sqrt(v.getDouble()), idx); + }); + return result; + } + /** * Gets the square of an array. * @@ -120,6 +134,23 @@ public static FloatNdArray square(FloatNdArray a) { return result; } + /** + * Gets the square of an array. + * + * @param a the array + * @return the square of the array. + */ + public static DoubleNdArray square(DoubleNdArray a) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + int nDims = a.shape().numDimensions(); + a.elements(nDims - 1) + .forEachIndexed( + (idx, v) -> { + result.setDouble(v.getDouble() * v.getDouble(), idx); + }); + return result; + } + /** * Adds two arrays * @@ -568,6 +599,18 @@ public static FloatNdArray sum(FloatNdArray a) { return NdArrays.scalarOf(sum.get()); } + /** + * Sum all elements of an array + * + * @param a the array + * @return an a array with one element containing the sum. + */ + public static DoubleNdArray sum(DoubleNdArray a) { + AtomicReference sum = new AtomicReference(0D); + a.scalars().forEach(f -> sum.set(sum.get() + f.getDouble())); + return NdArrays.scalarOf(sum.get()); + } + /** * Sum all elements of an array based on the specified axis * @@ -579,6 +622,17 @@ public static FloatNdArray sum(FloatNdArray a, int axis) { return sum(a, axis, false); } + /** + * Sum all elements of an array based on the specified axis + * + * @param a the array + * @param axis the axis to sum + * @return an a array the sum over the axis less the diemsnion + */ + public static DoubleNdArray sum(DoubleNdArray a, int axis) { + return sum(a, axis, false); + } + /** * Sum all elements of an array based on the specified axis * @@ -618,6 +672,45 @@ public static FloatNdArray sum(FloatNdArray a, int axis, boolean keepDims) { } } + /** + * Sum all elements of an array based on the specified axis + * + * @param a the array + * @param axis the axis to sum + * @param keepDims indicates whether the dimensions over the sum should be kept or not. + * @return an a array the sum over the axis + */ + public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { + Shape shape = a.shape(); + int nDims = shape.numDimensions(); + int xis = nDims - 1 - axis; + long totalSize = shape.size(); + long axisSize = shape.size(xis); + final double[] sums = new double[(int) axisSize]; + + a.scalars() + .forEachIndexed( + (idx, f) -> { + sums[(int) idx[xis]] += f.getDouble(); + }); + + if (keepDims) { + long[] newDims = shape.asArray(); + newDims[axis] = 1; + final AtomicInteger counter = new AtomicInteger(); + DoubleNdArray arrayK = NdArrays.ofDoubles(Shape.of(newDims)); + arrayK + .elements(newDims.length - 1) + .forEachIndexed( + (idx, v) -> { + v.setDouble(sums[counter.getAndAdd(1)]); + }); + return arrayK; + } else { + return NdArrays.vectorOf(sums); + } + } + /** * Sum all elements of an array based on the specified axis * diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java index db39a330522..2c252d467c7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.utils; import org.tensorflow.*; +import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -493,15 +494,16 @@ public void evaluate(FloatNdArray input, Predicate predicate) { } /** - * Print the input to standard out + * Evaluates the input against the expected value * - - * @param input the operand to print - * @param the data type of the input + * @param input the operand to evaluate + * @param predicate The Predicate that evaluates the each value from input + * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void print(Operand input) { - print(new PrintWriter(new OutputStreamWriter(System.out)), input.asOutput()); + public void evaluate(DoubleNdArray input, Predicate predicate) { + input.scalars().forEach(f -> assertTrue(predicate.test(f.getDouble()))); } + /** * Print the input *