diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs index 8a3b3046105551..3a3be2c1f2dab7 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/AllowedRecordType.cs @@ -27,12 +27,16 @@ internal enum AllowedRecordTypes : uint Nulls = ObjectNull | ObjectNullMultiple256 | ObjectNullMultiple, /// - /// Any .NET object (a primitive, a reference type, a reference or single null). + /// Any .NET object (a class, primitive type or an array). /// AnyObject = MemberPrimitiveTyped | ArraySingleObject | ArraySinglePrimitive | ArraySingleString | BinaryArray | ClassWithId | ClassWithMembersAndTypes | SystemClassWithMembersAndTypes | BinaryObjectString - | MemberReference - | ObjectNull, + | MemberReference, + + /// + /// Any .NET object or a reference or a single null. + /// + AnyObjectOrNullOrReference = AnyObject | ObjectNull | MemberReference, } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs index 37e94842719a90..128eba6b500c63 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleObjectRecord.cs @@ -68,7 +68,7 @@ internal static ArraySingleObjectRecord Decode(BinaryReader reader) internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetAllowedRecordType() { // An array of objects can contain any Object or multiple nulls. - const AllowedRecordTypes Allowed = AllowedRecordTypes.AnyObject | AllowedRecordTypes.Nulls; + const AllowedRecordTypes Allowed = AllowedRecordTypes.AnyObjectOrNullOrReference | AllowedRecordTypes.Nulls; return (Allowed, default); } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs index de248bcef76755..33f74d548ed016 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/ArraySingleStringRecord.cs @@ -54,11 +54,6 @@ internal override (AllowedRecordTypes allowed, PrimitiveType primitiveType) GetA if (record is MemberReferenceRecord memberReference) { record = memberReference.GetReferencedRecord(); - - if (record is not BinaryObjectStringRecord) - { - ThrowHelper.ThrowInvalidReference(); - } } if (record is BinaryObjectStringRecord stringRecord) diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberReferenceRecord.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberReferenceRecord.cs index 14bd4e7ff1f2d1..acb4d90f6b3d93 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberReferenceRecord.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberReferenceRecord.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Formats.Nrbf.Utils; using System.IO; using System.Reflection.Metadata; @@ -15,16 +16,19 @@ namespace System.Formats.Nrbf; /// internal sealed class MemberReferenceRecord : SerializationRecord { - private MemberReferenceRecord(SerializationRecordId reference, RecordMap recordMap) + private MemberReferenceRecord(SerializationRecordId reference, RecordMap recordMap, AllowedRecordTypes referencedRecordType) { Reference = reference; RecordMap = recordMap; + ReferencedRecordType = referencedRecordType; } public override SerializationRecordType RecordType => SerializationRecordType.MemberReference; internal SerializationRecordId Reference { get; } + private AllowedRecordTypes ReferencedRecordType { get; } + private RecordMap RecordMap { get; } // MemberReferenceRecord has no Id, which makes it impossible to create a cycle @@ -35,8 +39,26 @@ private MemberReferenceRecord(SerializationRecordId reference, RecordMap recordM internal override object? GetValue() => GetReferencedRecord().GetValue(); - internal static MemberReferenceRecord Decode(BinaryReader reader, RecordMap recordMap) - => new(SerializationRecordId.Decode(reader), recordMap); + internal static MemberReferenceRecord Decode(BinaryReader reader, RecordMap recordMap, AllowedRecordTypes allowed) + { + SerializationRecordId reference = SerializationRecordId.Decode(reader); + + // We were supposed to decode a record of specific type or a reference to it. + // Since a reference was decoded and we don't know when the referenced record will be provided. + // We just store the allowed record type and are going to check it later. + AllowedRecordTypes referencedRecordType = allowed & ~(AllowedRecordTypes.MemberReference | AllowedRecordTypes.Nulls); + + return new MemberReferenceRecord(reference, recordMap, referencedRecordType); + } internal SerializationRecord GetReferencedRecord() => RecordMap.GetRecord(Reference); + + internal void VerifyReferencedRecordType(SerializationRecord serializationRecord) + { + if (((uint)ReferencedRecordType & (1u << (byte)serializationRecord.RecordType)) == 0) + { + // We expected a reference to a record of a different type. + ThrowHelper.ThrowInvalidReference(); + } + } } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs index 9843a0b71f04c0..8fcae2bf2aafea 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/MemberTypeInfo.cs @@ -74,6 +74,8 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt | AllowedRecordTypes.ObjectNull | AllowedRecordTypes.MemberReference; const AllowedRecordTypes ObjectArray = AllowedRecordTypes.ArraySingleObject | AllowedRecordTypes.ObjectNull | AllowedRecordTypes.MemberReference; + const AllowedRecordTypes NonPrimitiveArray = AllowedRecordTypes.BinaryArray + | AllowedRecordTypes.ObjectNull | AllowedRecordTypes.MemberReference; // Every string can be a string, a null or a reference (to a string) const AllowedRecordTypes Strings = AllowedRecordTypes.BinaryObjectString @@ -92,13 +94,53 @@ internal static MemberTypeInfo Decode(BinaryReader reader, int count, PayloadOpt { BinaryType.Primitive => (default, (PrimitiveType)additionalInfo!), BinaryType.String => (Strings, default), - BinaryType.Object => (AllowedRecordTypes.AnyObject, default), + BinaryType.Object => (AllowedRecordTypes.AnyObjectOrNullOrReference, default), BinaryType.StringArray => (StringArray, default), BinaryType.PrimitiveArray => (PrimitiveArray, default), - BinaryType.Class => (NonSystemClass, default), - BinaryType.SystemClass => (SystemClass, default), - _ => (ObjectArray, default) + BinaryType.Class => (((ClassTypeInfo)additionalInfo!).TypeName.IsArray ? NonPrimitiveArray : NonSystemClass, default), + BinaryType.SystemClass => (MapSystemClassTypeName((TypeName)additionalInfo!), default), + _ => (ObjectArray, default), }; + + static AllowedRecordTypes MapSystemClassTypeName(TypeName typeName) + { + if (!typeName.IsArray) + { + return SystemClass; + } + else if (typeName.IsSZArray) + { + TypeName elementTypeName = typeName.GetElementType(); + if (elementTypeName.IsSimple && elementTypeName.FullName.StartsWith("System.", StringComparison.Ordinal)) + { + switch (elementTypeName.FullName) + { + case "System.Boolean": + case "System.Byte": + case "System.SByte": + case "System.Char": + case "System.Int16": + case "System.UInt16": + case "System.Int32": + case "System.UInt32": + case "System.Int64": + case "System.UInt64": + case "System.Single": + case "System.Double": + case "System.Decimal": + case "System.DateTime": + case "System.TimeSpan": + // BinaryFormatter should use BinaryType.PrimitiveArray for these primitive types, + // but it uses BinaryType.SystemClass and we need this workaround. + return PrimitiveArray; + default: + break; + } + } + } + + return NonPrimitiveArray; + } } internal bool ShouldBeRepresentedAsArrayOfClassRecords() diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs index de4b24b6e46e1b..8167ea594c8e64 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/NrbfDecoder.cs @@ -167,11 +167,14 @@ private static SerializationRecord Decode(BinaryReader reader, PayloadOptions op Stack readStack = new(); RecordMap recordMap = new(); - // Everything has to start with a header - var header = (SerializedStreamHeaderRecord)DecodeNext(reader, recordMap, AllowedRecordTypes.SerializedStreamHeader, options, out _); - // and can be followed by any Object, BinaryLibrary and a MessageEnd. - const AllowedRecordTypes Allowed = AllowedRecordTypes.AnyObject - | AllowedRecordTypes.BinaryLibrary | AllowedRecordTypes.MessageEnd; + // Every NRBF payload has to start with a header + AllowedRecordTypes allowed = AllowedRecordTypes.SerializedStreamHeader; + var header = (SerializedStreamHeaderRecord)DecodeNext(reader, recordMap, allowed, options, out _); + + // The root can be any Object or BinaryLibrary, but not a reference. + allowed = AllowedRecordTypes.AnyObject | AllowedRecordTypes.BinaryLibrary; + SerializationRecord rootRecord = DecodeNext(reader, recordMap, allowed, options, out _); + PushFirstNestedRecordInfo(rootRecord, readStack); SerializationRecordType recordType; SerializationRecord nextRecord; @@ -184,16 +187,7 @@ private static SerializationRecord Decode(BinaryReader reader, PayloadOptions op if (nextInfo.Allowed != AllowedRecordTypes.None) { // Decode the next Record - do - { - nextRecord = DecodeNext(reader, recordMap, nextInfo.Allowed, options, out _); - // BinaryLibrary often precedes class records. - // It has been already added to the RecordMap and it must not be added - // to the array record, so simply read next record. - // It's possible to read multiple BinaryLibraryRecord in a row, hence the loop. - } - while (nextRecord is BinaryLibraryRecord); - + nextRecord = DecodeNext(reader, recordMap, nextInfo.Allowed, options, out _); // Handle it: // - add to the parent records list, // - push next info if there are remaining nested records to read. @@ -210,7 +204,20 @@ private static SerializationRecord Decode(BinaryReader reader, PayloadOptions op } } - nextRecord = DecodeNext(reader, recordMap, Allowed, options, out recordType); + if (recordMap.UnresolvedReferences == 0) + { + // There are no unresolved references, so the End is the only allowed record. + allowed = AllowedRecordTypes.MessageEnd; + } + else + { + // There are unresolved references and we don't know in what order they are going to appear. + // We allow for any Object (which does not include references or nulls). + // The actual type validation is going to be performed by RecordMap.Add. + allowed = AllowedRecordTypes.AnyObject | AllowedRecordTypes.BinaryLibrary; + } + + nextRecord = DecodeNext(reader, recordMap, allowed, options, out recordType, isReferencedRecord: true); PushFirstNestedRecordInfo(nextRecord, readStack); } while (recordType != SerializationRecordType.MessageEnd); @@ -220,31 +227,41 @@ private static SerializationRecord Decode(BinaryReader reader, PayloadOptions op } private static SerializationRecord DecodeNext(BinaryReader reader, RecordMap recordMap, - AllowedRecordTypes allowed, PayloadOptions options, out SerializationRecordType recordType) + AllowedRecordTypes allowed, PayloadOptions options, out SerializationRecordType recordType, bool isReferencedRecord = false) { - recordType = reader.ReadSerializationRecordType(allowed); + SerializationRecord? record; - SerializationRecord record = recordType switch + do { - SerializationRecordType.ArraySingleObject => ArraySingleObjectRecord.Decode(reader), - SerializationRecordType.ArraySinglePrimitive => DecodeArraySinglePrimitiveRecord(reader), - SerializationRecordType.ArraySingleString => ArraySingleStringRecord.Decode(reader), - SerializationRecordType.BinaryArray => BinaryArrayRecord.Decode(reader, recordMap, options), - SerializationRecordType.BinaryLibrary => BinaryLibraryRecord.Decode(reader, options), - SerializationRecordType.BinaryObjectString => BinaryObjectStringRecord.Decode(reader), - SerializationRecordType.ClassWithId => ClassWithIdRecord.Decode(reader, recordMap), - SerializationRecordType.ClassWithMembersAndTypes => ClassWithMembersAndTypesRecord.Decode(reader, recordMap, options), - SerializationRecordType.MemberPrimitiveTyped => DecodeMemberPrimitiveTypedRecord(reader), - SerializationRecordType.MemberReference => MemberReferenceRecord.Decode(reader, recordMap), - SerializationRecordType.MessageEnd => MessageEndRecord.Singleton, - SerializationRecordType.ObjectNull => ObjectNullRecord.Instance, - SerializationRecordType.ObjectNullMultiple => ObjectNullMultipleRecord.Decode(reader), - SerializationRecordType.ObjectNullMultiple256 => ObjectNullMultiple256Record.Decode(reader), - SerializationRecordType.SerializedStreamHeader => SerializedStreamHeaderRecord.Decode(reader), - _ => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options), - }; + recordType = reader.ReadSerializationRecordType(allowed); - recordMap.Add(record); + record = recordType switch + { + SerializationRecordType.ArraySingleObject => ArraySingleObjectRecord.Decode(reader), + SerializationRecordType.ArraySinglePrimitive => DecodeArraySinglePrimitiveRecord(reader), + SerializationRecordType.ArraySingleString => ArraySingleStringRecord.Decode(reader), + SerializationRecordType.BinaryArray => BinaryArrayRecord.Decode(reader, recordMap, options), + SerializationRecordType.BinaryLibrary => BinaryLibraryRecord.Decode(reader, options), + SerializationRecordType.BinaryObjectString => BinaryObjectStringRecord.Decode(reader), + SerializationRecordType.ClassWithId => ClassWithIdRecord.Decode(reader, recordMap), + SerializationRecordType.ClassWithMembersAndTypes => ClassWithMembersAndTypesRecord.Decode(reader, recordMap, options), + SerializationRecordType.MemberPrimitiveTyped => DecodeMemberPrimitiveTypedRecord(reader), + SerializationRecordType.MemberReference => MemberReferenceRecord.Decode(reader, recordMap, allowed), + SerializationRecordType.MessageEnd => MessageEndRecord.Singleton, + SerializationRecordType.ObjectNull => ObjectNullRecord.Instance, + SerializationRecordType.ObjectNullMultiple => ObjectNullMultipleRecord.Decode(reader), + SerializationRecordType.ObjectNullMultiple256 => ObjectNullMultiple256Record.Decode(reader), + SerializationRecordType.SerializedStreamHeader => SerializedStreamHeaderRecord.Decode(reader), + _ => SystemClassWithMembersAndTypesRecord.Decode(reader, recordMap, options), + }; + + recordMap.Add(record, isReferencedRecord); + + // BinaryLibrary often precedes class records. + // It has been already added to the RecordMap and it must not be added + // to the array record or class member values, so simply read next record. + // It's possible to read multiple BinaryLibraryRecord in a row, hence the loop. + } while (recordType == SerializationRecordType.BinaryLibrary); return record; } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs index 04a4d0e085048d..c30602cc06b109 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/RecordMap.cs @@ -3,7 +3,9 @@ using System.Collections; using System.Collections.Generic; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Formats.Nrbf.Utils; using System.Runtime.InteropServices; using System.Runtime.Serialization; @@ -21,6 +23,8 @@ internal sealed class RecordMap : IReadOnlyDictionary _map[objectId]; + internal int UnresolvedReferences { get; private set; } + public bool ContainsKey(SerializationRecordId key) => _map.ContainsKey(key); public bool TryGetValue(SerializationRecordId key, [MaybeNullWhen(false)] out SerializationRecord value) => _map.TryGetValue(key, out value); @@ -29,34 +33,83 @@ internal sealed class RecordMap : IReadOnlyDictionary _map.GetEnumerator(); - internal void Add(SerializationRecord record) + internal void Add(SerializationRecord record, bool isReferencedRecord) { + switch (record.RecordType) + { + case SerializationRecordType.SerializedStreamHeader: + case SerializationRecordType.ObjectNull: + case SerializationRecordType.MessageEnd: + case SerializationRecordType.ObjectNullMultiple256: + case SerializationRecordType.ObjectNullMultiple: + case SerializationRecordType.MemberPrimitiveTyped when record.Id.IsDefault: + // These records have no Id and don't need any verification. + Debug.Assert(record.Id.IsDefault); + return; + case SerializationRecordType.BinaryLibrary: + if (!TryAdd(record.Id, record)) + { + ThrowHelper.ThrowDuplicateSerializationRecordId(record.Id); + } + return; + case SerializationRecordType.MemberReference: + MemberReferenceRecord memberReferenceRecord = (MemberReferenceRecord)record; + + if (_map.TryGetValue(memberReferenceRecord.Reference, out SerializationRecord? stored)) + { + if (stored.RecordType != SerializationRecordType.MemberReference) + { + // When reference was stored, we have persisted the allowed record type. + // Now is the time to check if the provided record matches expectations. + memberReferenceRecord.VerifyReferencedRecordType(stored); + } + } + else + { + // We store the reference now and when the record is provided we are going to perform type check. + _map.Add(memberReferenceRecord.Reference, record); + UnresolvedReferences++; + } + return; + default: + break; + } + + Debug.Assert(!record.Id.IsDefault); + // From https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nrbf/0a192be0-58a1-41d0-8a54-9c91db0ab7bf: // "If the ObjectId is not referenced by any MemberReference in the serialization stream, // then the ObjectId SHOULD be positive, but MAY be negative." - if (!record.Id.Equals(SerializationRecordId.NoId)) + if (isReferencedRecord) { if (record.Id._id < 0) { - // Negative record Ids should never be referenced. Duplicate negative ids can be - // exported by the writer. The root object Id can be negative. - _map[record.Id] = record; + // Negative record Ids should never be referenced. + ThrowHelper.ThrowInvalidReference(); + } + else if (!_map.TryGetValue(record.Id, out SerializationRecord? stored) || stored is not MemberReferenceRecord memberReferenceRecord) + { + // The id was either unexpected or there was no reference stored for it. + ThrowHelper.ThrowForUnexpectedRecordType((byte)record.RecordType); } else { -#if NET - if (_map.TryAdd(record.Id, record)) - { - return; - } -#else - if (!_map.ContainsKey(record.Id)) - { - _map.Add(record.Id, record); - return; - } -#endif - throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, record.Id)); + memberReferenceRecord.VerifyReferencedRecordType(record); + } + + _map[record.Id] = record; + UnresolvedReferences--; + } + else + { + if (record.Id._id < 0) + { + // Negative ids can be exported by the writer. The root object Id can be negative. + _map[record.Id] = record; + } + else if (!TryAdd(record.Id, record)) + { + ThrowHelper.ThrowDuplicateSerializationRecordId(record.Id); } } } @@ -74,6 +127,20 @@ internal SerializationRecord GetRootRecord(SerializedStreamHeaderRecord header) return rootRecord; } + private bool TryAdd(SerializationRecordId id, SerializationRecord record) + { +#if NET + return _map.TryAdd(id, record); +#else + if (!_map.ContainsKey(id)) + { + _map.Add(id, record); + return true; + } + return false; +#endif + } + internal SerializationRecord GetRecord(SerializationRecordId recordId) => _map.TryGetValue(recordId, out SerializationRecord? record) ? record diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordId.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordId.cs index 7f51525e6e1139..a448a4cebbf8d8 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordId.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/SerializationRecordId.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Formats.Nrbf.Utils; using System.IO; using System.Linq; @@ -15,6 +16,7 @@ namespace System.Formats.Nrbf; /// /// The ID of . /// +[DebuggerDisplay("{_id}")] public readonly struct SerializationRecordId : IEquatable { #pragma warning disable CS0649 // the default value is used on purpose @@ -45,4 +47,6 @@ internal static SerializationRecordId Decode(BinaryReader reader) /// public override int GetHashCode() => HashCode.Combine(_id); + + internal bool IsDefault => _id == default; } diff --git a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/ThrowHelper.cs b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/ThrowHelper.cs index 55febf77533f9d..a05d10ee5ce653 100644 --- a/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/ThrowHelper.cs +++ b/src/libraries/System.Formats.Nrbf/src/System/Formats/Nrbf/Utils/ThrowHelper.cs @@ -17,6 +17,9 @@ internal static void ThrowInvalidReference() internal static void ThrowInvalidTypeName(string name) => throw new SerializationException(SR.Format(SR.Serialization_InvalidTypeName, name)); + internal static void ThrowDuplicateSerializationRecordId(SerializationRecordId id) + => throw new SerializationException(SR.Format(SR.Serialization_DuplicateSerializationRecordId, id._id)); + internal static void ThrowUnexpectedNullRecordCount() => throw new SerializationException(SR.Serialization_UnexpectedNullRecordCount); diff --git a/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs b/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs index b1625c7ca92870..f838df06cdb947 100644 --- a/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/InvalidInputTests.cs @@ -585,4 +585,43 @@ public void SurrogateCharacter() Assert.Throws(() => NrbfDecoder.Decode(stream)); } + + public static IEnumerable ThrowsForOrphanedRecord_Args() + { + SerializationRecordType[] supported = + { + SerializationRecordType.BinaryObjectString, + SerializationRecordType.ArraySingleString, + SerializationRecordType.MemberPrimitiveTyped, + SerializationRecordType.ArraySinglePrimitive, + SerializationRecordType.SystemClassWithMembersAndTypes, + SerializationRecordType.ClassWithMembersAndTypes + }; + + for (int i = 0; i < supported.Length; i++) + { + for (int j = 0; j < supported.Length; j++) + { + yield return new object[] { supported[i], supported[j] }; + } + } + } + + [Theory] + [MemberData(nameof(ThrowsForOrphanedRecord_Args))] + public void ThrowsForOrphanedRecord(SerializationRecordType root, SerializationRecordType orphaned) + { + int objectId = 1; + using MemoryStream stream = new(); + BinaryWriter writer = new(stream, Encoding.UTF8); + + WriteSerializedStreamHeader(writer); + WriteValidRecord(writer, root, ref objectId); + WriteValidRecord(writer, orphaned, ref objectId); + + writer.Write((byte)SerializationRecordType.MessageEnd); + + stream.Position = 0; + Assert.Throws(() => NrbfDecoder.Decode(stream)); + } } diff --git a/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs b/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs index 0c7bd2045fa1f8..236671f14db305 100644 --- a/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs +++ b/src/libraries/System.Formats.Nrbf/tests/ReadTests.cs @@ -60,4 +60,51 @@ protected static void WriteBinaryLibrary(BinaryWriter writer, int objectId, stri writer.Write(objectId); writer.Write(libraryName); } + + protected static void WriteValidRecord(BinaryWriter writer, SerializationRecordType recordType, ref int objectId) + { + const int LibraryId = 12345; + if (recordType == SerializationRecordType.ClassWithMembersAndTypes) + { + WriteBinaryLibrary(writer, LibraryId, "libName"); + } + + writer.Write((byte)recordType); + writer.Write(objectId++); + + if (recordType == SerializationRecordType.BinaryObjectString) + { + writer.Write("aString"); + } + else if (recordType == SerializationRecordType.ArraySingleString) + { + writer.Write(2); // array length + WriteValidRecord(writer, SerializationRecordType.BinaryObjectString, ref objectId); + WriteValidRecord(writer, SerializationRecordType.BinaryObjectString, ref objectId); + } + else if (recordType == SerializationRecordType.MemberPrimitiveTyped) + { + writer.Write((byte)PrimitiveType.Boolean); + writer.Write(true); + } + else if (recordType == SerializationRecordType.ArraySinglePrimitive) + { + writer.Write(3); // array length + writer.Write((byte)PrimitiveType.Int32); + writer.Write(1); + writer.Write(2); + writer.Write(3); + } + else if (recordType == SerializationRecordType.SystemClassWithMembersAndTypes) + { + writer.Write("TypeName"); + writer.Write(0); // member count + } + else if (recordType == SerializationRecordType.ClassWithMembersAndTypes) + { + writer.Write("TypeName"); + writer.Write(0); // member count + writer.Write(LibraryId); + } + } }