diff --git a/spring-aop/src/main/java/org/springframework/aop/framework/CoroutinesUtils.java b/spring-aop/src/main/java/org/springframework/aop/framework/CoroutinesUtils.java index 3cb74cbdf782..083c869434a6 100644 --- a/spring-aop/src/main/java/org/springframework/aop/framework/CoroutinesUtils.java +++ b/spring-aop/src/main/java/org/springframework/aop/framework/CoroutinesUtils.java @@ -17,6 +17,7 @@ package org.springframework.aop.framework; import kotlin.coroutines.Continuation; +import kotlinx.coroutines.flow.Flow; import kotlinx.coroutines.reactive.ReactiveFlowKt; import kotlinx.coroutines.reactor.MonoKt; import org.jspecify.annotations.Nullable; @@ -35,6 +36,9 @@ static Object asFlow(@Nullable Object publisher) { if (publisher instanceof Publisher rsPublisher) { return ReactiveFlowKt.asFlow(rsPublisher); } + else if (publisher instanceof Flow) { + return publisher; + } else { throw new IllegalArgumentException("Not a Reactive Streams Publisher: " + publisher); } diff --git a/spring-aop/src/test/kotlin/org/springframework/aop/framework/CoroutinesUtilsTests.kt b/spring-aop/src/test/kotlin/org/springframework/aop/framework/CoroutinesUtilsTests.kt index 188b72f8b1ea..0cd0378fd311 100644 --- a/spring-aop/src/test/kotlin/org/springframework/aop/framework/CoroutinesUtilsTests.kt +++ b/spring-aop/src/test/kotlin/org/springframework/aop/framework/CoroutinesUtilsTests.kt @@ -18,6 +18,7 @@ package org.springframework.aop.framework import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.assertj.core.api.Assertions.assertThat @@ -72,4 +73,16 @@ class CoroutinesUtilsTests { } } + @Test + @Suppress("UNCHECKED_CAST") + fun flowAsFlow() { + val value1 = "foo" + val value2 = "bar" + val values = flowOf(value1, value2) + val flow = CoroutinesUtils.asFlow(values) as Flow + runBlocking { + assertThat(flow.toList()).containsExactly(value1, value2) + } + } + } diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceMethod.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceMethod.java index 0380698dda30..4975cc9141ca 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceMethod.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceMethod.java @@ -30,6 +30,7 @@ import reactor.core.publisher.Mono; import org.springframework.core.DefaultParameterNameDiscoverer; +import org.springframework.core.KotlinDetector; import org.springframework.core.MethodParameter; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ReactiveAdapter; @@ -54,6 +55,8 @@ */ final class RSocketServiceMethod { + private static final String COROUTINES_FLOW_CLASS_NAME = "kotlinx.coroutines.flow.Flow"; + private final Method method; private final MethodParameter[] parameters; @@ -82,6 +85,10 @@ private static MethodParameter[] initMethodParameters(Method method) { if (count == 0) { return new MethodParameter[0]; } + if (KotlinDetector.isSuspendingFunction(method)) { + count -= 1; + } + DefaultParameterNameDiscoverer nameDiscoverer = new DefaultParameterNameDiscoverer(); MethodParameter[] parameters = new MethodParameter[count]; for (int i = 0; i < count; i++) { @@ -129,10 +136,16 @@ private static Function initResponseFunction( MethodParameter returnParam = new MethodParameter(method, -1); Class returnType = returnParam.getParameterType(); + boolean isUnwrapped = KotlinDetector.isSuspendingFunction(method) && + !COROUTINES_FLOW_CLASS_NAME.equals(returnParam.getParameterType().getName()); + if (isUnwrapped) { + returnType = Mono.class; + } + ReactiveAdapter reactiveAdapter = reactiveRegistry.getAdapter(returnType); MethodParameter actualParam = (reactiveAdapter != null ? returnParam.nested() : returnParam.nestedIfOptional()); - Class actualType = actualParam.getNestedParameterType(); + Class actualType = isUnwrapped ? actualParam.getParameterType() : actualParam.getNestedParameterType(); Function> responseFunction; if (ClassUtils.isVoidType(actualType) || (reactiveAdapter != null && reactiveAdapter.isNoValue())) { @@ -147,7 +160,8 @@ else if (reactiveAdapter == null) { } else { ParameterizedTypeReference payloadType = - ParameterizedTypeReference.forType(actualParam.getNestedGenericParameterType()); + ParameterizedTypeReference.forType(isUnwrapped ? actualParam.getGenericParameterType() : + actualParam.getNestedGenericParameterType()); responseFunction = values -> ( reactiveAdapter.isMultiValue() ? diff --git a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceProxyFactory.java b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceProxyFactory.java index 6c75440e11b6..796ace49a3f6 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceProxyFactory.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/rsocket/service/RSocketServiceProxyFactory.java @@ -31,6 +31,7 @@ import org.springframework.aop.framework.ProxyFactory; import org.springframework.aop.framework.ReflectiveMethodInvocation; +import org.springframework.core.KotlinDetector; import org.springframework.core.MethodIntrospector; import org.springframework.core.ReactiveAdapterRegistry; import org.springframework.core.annotation.AnnotatedElementUtils; @@ -246,7 +247,9 @@ private ServiceMethodInterceptor(List methods) { Method method = invocation.getMethod(); RSocketServiceMethod serviceMethod = this.serviceMethods.get(method); if (serviceMethod != null) { - return serviceMethod.invoke(invocation.getArguments()); + @Nullable Object[] arguments = KotlinDetector.isSuspendingFunction(method) ? + resolveCoroutinesArguments(invocation.getArguments()) : invocation.getArguments(); + return serviceMethod.invoke(arguments); } if (method.isDefault()) { if (invocation instanceof ReflectiveMethodInvocation reflectiveMethodInvocation) { @@ -256,6 +259,12 @@ private ServiceMethodInterceptor(List methods) { } throw new IllegalStateException("Unexpected method invocation: " + method); } + + private static Object[] resolveCoroutinesArguments(@Nullable Object[] args) { + Object[] functionArgs = new Object[args.length - 1]; + System.arraycopy(args, 0, functionArgs, 0, args.length - 1); + return functionArgs; + } } } diff --git a/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/service/RSocketServiceMethodKotlinTests.kt b/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/service/RSocketServiceMethodKotlinTests.kt new file mode 100644 index 000000000000..2764cea7f7a1 --- /dev/null +++ b/spring-messaging/src/test/kotlin/org/springframework/messaging/rsocket/service/RSocketServiceMethodKotlinTests.kt @@ -0,0 +1,134 @@ +/* + * Copyright 2002-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.messaging.rsocket.service + +import io.rsocket.util.DefaultPayload +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flowOf +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.reactive.asFlow +import kotlinx.coroutines.runBlocking +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.springframework.messaging.rsocket.RSocketRequester +import org.springframework.messaging.rsocket.RSocketStrategies +import org.springframework.messaging.rsocket.TestRSocket +import org.springframework.util.MimeTypeUtils.TEXT_PLAIN +import reactor.core.publisher.Flux +import reactor.core.publisher.Mono + +/** + * Kotlin tests for [RSocketServiceMethod]. + * + * @author Dmitry Sulman + */ +class RSocketServiceMethodKotlinTests { + + private lateinit var rsocket: TestRSocket + + private lateinit var proxyFactory: RSocketServiceProxyFactory + + @BeforeEach + fun setUp() { + rsocket = TestRSocket() + val requester = RSocketRequester.wrap(rsocket, TEXT_PLAIN, TEXT_PLAIN, RSocketStrategies.create()) + proxyFactory = RSocketServiceProxyFactory.builder(requester).build() + } + + @Test + fun fireAndForget(): Unit = runBlocking { + val service = proxyFactory.createClient(SuspendingFunctionsService::class.java) + + val requestPayload = "request" + service.fireAndForget(requestPayload) + + assertThat(rsocket.savedMethodName).isEqualTo("fireAndForget") + assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("ff") + assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload) + } + + @Test + fun requestResponse(): Unit = runBlocking { + val service = proxyFactory.createClient(SuspendingFunctionsService::class.java) + + val requestPayload = "request" + val responsePayload = "response" + rsocket.setPayloadMonoToReturn(Mono.just(DefaultPayload.create(responsePayload))) + val response = service.requestResponse(requestPayload) + + assertThat(response).isEqualTo(responsePayload) + assertThat(rsocket.savedMethodName).isEqualTo("requestResponse") + assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rr") + assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload) + } + + @Test + fun requestStream(): Unit = runBlocking { + val service = proxyFactory.createClient(SuspendingFunctionsService::class.java) + + val requestPayload = "request" + val responsePayload1 = "response1" + val responsePayload2 = "response2" + rsocket.setPayloadFluxToReturn( + Flux.just(DefaultPayload.create(responsePayload1), DefaultPayload.create(responsePayload2))) + val response = service.requestStream(requestPayload).toList() + + assertThat(response).containsExactly(responsePayload1, responsePayload2) + assertThat(rsocket.savedMethodName).isEqualTo("requestStream") + assertThat(rsocket.savedPayload?.metadataUtf8).isEqualTo("rs") + assertThat(rsocket.savedPayload?.dataUtf8).isEqualTo(requestPayload) + } + + @Test + fun requestChannel(): Unit = runBlocking { + val service = proxyFactory.createClient(SuspendingFunctionsService::class.java) + + val requestPayload1 = "request1" + val requestPayload2 = "request2" + val responsePayload1 = "response1" + val responsePayload2 = "response2" + rsocket.setPayloadFluxToReturn( + Flux.just(DefaultPayload.create(responsePayload1), DefaultPayload.create(responsePayload2))) + val response = service.requestChannel(flowOf(requestPayload1, requestPayload2)).toList() + + assertThat(response).containsExactly(responsePayload1, responsePayload2) + assertThat(rsocket.savedMethodName).isEqualTo("requestChannel") + + val savedPayloads = rsocket.savedPayloadFlux + ?.asFlow() + ?.map { it.dataUtf8 } + ?.toList() + assertThat(savedPayloads).containsExactly(requestPayload1, requestPayload2) + } + + private interface SuspendingFunctionsService { + + @RSocketExchange("ff") + suspend fun fireAndForget(input: String) + + @RSocketExchange("rr") + suspend fun requestResponse(input: String): String + + @RSocketExchange("rs") + suspend fun requestStream(input: String): Flow + + @RSocketExchange("rc") + suspend fun requestChannel(input: Flow): Flow + } +} \ No newline at end of file